import gradio as gr
import torch
import json
import re
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from ModelArchitecture import Transformer, ModelConfig, generate

# -----------------------------
# Load model and tokenizer
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
REPO_ID = "VirtualInsight/Lumen-Instruct"

# Download model files
model_path = hf_hub_download(repo_id=REPO_ID, filename="model.safetensors")
tokenizer_path = hf_hub_download(repo_id=REPO_ID, filename="tokenizer.json")
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")

# Initialize tokenizer and model
tokenizer = Tokenizer.from_file(tokenizer_path)
with open(config_path) as f:
    config = ModelConfig(**json.load(f))

model = Transformer(config).to(device)
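# strict=False ignores checkpoint keys that don't match this architecture
# (and vice versa); mismatched weights fail silently, so watch for
# untrained-looking output if the config and checkpoint drift apart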
model.load_state_dict(load_file(model_path, device=str(device)), strict=False)
model.eval()

# -----------------------------
# Special Tokens
# -----------------------------
EOS_TOKEN = "<|im_end|>"
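# Assumes <|im_end|> encodes to a single id in this tokenizer; if it ever
# split into multiple pieces, .ids[0] would pick the wrong stop token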
EOS_TOKEN_ID = tokenizer.encode(EOS_TOKEN).ids[0]
print(f"EOS token ID: {EOS_TOKEN_ID}")

# -----------------------------
# Generation Function
# -----------------------------
@torch.no_grad()
def generate_response(prompt, max_tokens=200, temperature=0.7, top_p=0.9):
    """
    Generates a clean assistant-only response, removing any echoed user text.
    """
    # ChatML-style prompt: a user turn followed by an open assistant turn
    formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

    # Tokenize
    input_ids = torch.tensor([tokenizer.encode(formatted_prompt).ids], dtype=torch.long, device=device)

    # Generate
    output = generate(
        model,
        input_ids,
        max_new_tokens=int(max_tokens),  # Gradio sliders pass floats
        temperature=temperature,
        top_k=50,
        top_p=top_p,
        do_sample=True,
        eos_token_id=EOS_TOKEN_ID,
    )

    # Decode without dropping special tokens, so the <|im_start|>/<|im_end|>
    # markers survive for the extraction below (Tokenizer.decode skips
    # special tokens by default)
    full_text = tokenizer.decode(output[0].tolist(), skip_special_tokens=False)

    # Extract assistant’s section
    if "<|im_start|>assistant" in full_text:
        response = full_text.split("<|im_start|>assistant")[-1]
        response = response.split("<|im_end|>")[0] if "<|im_end|>" in response else response
    else:
        response = full_text

    # Remove any leftover role markers; each match wipes from the marker to
    # the end of its line (no re.DOTALL)
    response = re.sub(r"(?i)\buser\b.*", "", response)
    response = re.sub(r"(?i)\bassistant\b.*", "", response)
    response = response.strip()

    # 🧹 Final cleanup: remove leading user echo if present
    lines = [line.strip() for line in response.splitlines() if line.strip()]
    if len(lines) >= 2 and (
        lines[0].lower() == prompt.strip().lower()  # exact echo
        or lines[0].rstrip("!?.,").lower() == prompt.strip().rstrip("!?.,").lower()  # punctuation variation
        or len(lines[0].split()) <= 3  # heuristic: treat any first line of <= 3 words as an echo
    ):
        lines = lines[1:]  # drop the first echo line

    clean_response = "\n".join(lines).strip()

    return clean_response
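
# Quick smoke test (optional): uncomment to run one prompt end-to-end
# before launching the UI; the prompt below is only an example.
# print(generate_response("What is the capital of France?", max_tokens=50))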

# -----------------------------
# Gradio Interface
# -----------------------------
demo = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(label="User Prompt", placeholder="Ask Lumen anything...", lines=3),
        gr.Slider(10, 500, value=200, label="Max Tokens"),
        gr.Slider(0.1, 2.0, value=0.7, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.9, label="Top-p"),
    ],
    outputs=gr.Textbox(label="Lumen’s Response", lines=10),
    title="Lumen Instruct Model",
    description="Lumen Instruct — a fine-tuned, instruction-following language model built on the Lumen Foundational Model.",
)
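
# Optional: serialize requests through Gradio's queue so concurrent users
# don't all hit the single model/device at once (sketch; API details vary
# slightly across Gradio versions).
# demo = demo.queue()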

# -----------------------------
# Launch
# -----------------------------
if __name__ == "__main__":
    demo.launch(share=True)