# Lumen-Instruct / app.py
import gradio as gr
import torch
import json
import re
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from ModelArchitecture import Transformer, ModelConfig, generate
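# ModelArchitecture is assumed to be a local module shipped alongside app.py in this Space,
# providing the custom Transformer, its ModelConfig, and a sampling `generate` helper.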
# -----------------------------
# Load model and tokenizer
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
REPO_ID = "VirtualInsight/Lumen-Instruct"
# Download model files
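# (hf_hub_download caches the files locally, so repeated launches reuse the cached copies)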
model_path = hf_hub_download(repo_id=REPO_ID, filename="model.safetensors")
tokenizer_path = hf_hub_download(repo_id=REPO_ID, filename="tokenizer.json")
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
# Initialize tokenizer and model
tokenizer = Tokenizer.from_file(tokenizer_path)
with open(config_path) as f:
    config = ModelConfig(**json.load(f))
model = Transformer(config).to(device)
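# strict=False tolerates minor key mismatches (e.g. unexpected buffers); it assumes the
# checkpoint otherwise matches the architecture described by config.json.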
model.load_state_dict(load_file(model_path, device=str(device)), strict=False)
model.eval()
# -----------------------------
# Special Tokens
# -----------------------------
EOS_TOKEN = "<|im_end|>"
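# Assumes the tokenizer encodes <|im_end|> as a single special token, so its first id is the EOS id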
EOS_TOKEN_ID = tokenizer.encode(EOS_TOKEN).ids[0]
print(f"EOS token ID: {EOS_TOKEN_ID}")
# -----------------------------
# Generation Function
# -----------------------------
@torch.no_grad()
def generate_response(prompt, max_tokens=200, temperature=0.7, top_p=0.9):
"""
Generates a clean assistant-only response, removing any echoed user text.
"""
# Chat-style prompt
formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
# Tokenize
input_ids = torch.tensor([tokenizer.encode(formatted_prompt).ids], dtype=torch.long, device=device)
    # Generate with top-k / nucleus sampling, stopping at the end-of-turn token
    output = generate(
        model,
        input_ids,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_k=50,
        top_p=top_p,
        do_sample=True,
        eos_token_id=EOS_TOKEN_ID,
    )
    # Decode the full generated sequence back to text
    full_text = tokenizer.decode(output[0].tolist())
    # Extract the assistant's section of the transcript
    if "<|im_start|>assistant" in full_text:
        response = full_text.split("<|im_start|>assistant")[-1]
        response = response.split("<|im_end|>")[0] if "<|im_end|>" in response else response
    else:
        response = full_text
    # Remove leftover role tokens and whitespace (strips from a bare "user"/"assistant"
    # word to the end of that line)
    response = re.sub(r"(?i)\buser\b.*", "", response)
    response = re.sub(r"(?i)\bassistant\b.*", "", response)
    response = response.strip()
    # 🧹 Final cleanup: remove a leading echo of the user prompt if present
    lines = [line.strip() for line in response.splitlines() if line.strip()]
    if len(lines) >= 2 and (
        lines[0].lower() == prompt.strip().lower()  # exact echo
        or lines[0].rstrip("!?.,").lower() == prompt.strip().rstrip("!?.,").lower()  # punctuation variation
        or len(lines[0].split()) <= 3  # very short echo like "Hello!"
    ):
        lines = lines[1:]  # drop the first echoed line
    clean_response = "\n".join(lines).strip()
    return clean_response
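# Optional local sanity check (hypothetical example; uncomment to try generation without the UI):
# print(generate_response("Hello, who are you?", max_tokens=50))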
# -----------------------------
# Gradio Interface
# -----------------------------
demo = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(label="User Prompt", placeholder="Ask Lumen anything...", lines=3),
        gr.Slider(10, 500, value=200, label="Max Tokens"),
        gr.Slider(0.1, 2.0, value=0.7, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.9, label="Top-p"),
    ],
    outputs=gr.Textbox(label="Lumen’s Response", lines=10),
    title="Lumen Instruct Model",
    description="Lumen Instruct — a fine-tuned, instruction-following language model built on the Lumen Foundational Model.",
)
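# Optional: demo.queue() enables Gradio's request queue, which can help if the Space sees concurrent traffic.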
# -----------------------------
# Launch
# -----------------------------
if __name__ == "__main__":
    demo.launch(share=True)