VirtualInsight committed
Commit 47d3bb2 · verified · 1 Parent(s): fff1dce

Create app.py


Inference Implementation

Files changed (1):
  app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
import gradio as gr
import torch
import json
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
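# ModelArchitecture is this repo's own module, not a pip package: it is
# assumed to define the Transformer, its ModelConfig, and the generate()
# sampling loop, and presumably ships alongside app.py in the Space.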
from ModelArchitecture import Transformer, ModelConfig, generate
from safetensors.torch import load_file

# -----------------------------
# Load model and tokenizer
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
REPO_ID = "VirtualInsight/Lumen-Instruct"

# Download model assets from Hugging Face Hub
model_path = hf_hub_download(repo_id=REPO_ID, filename="model.safetensors")
tokenizer_path = hf_hub_download(repo_id=REPO_ID, filename="tokenizer.json")
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")

# Initialize tokenizer and model
tokenizer = Tokenizer.from_file(tokenizer_path)
with open(config_path) as f:
    config = ModelConfig(**json.load(f))

model = Transformer(config).to(device)
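# Note: strict=False ignores checkpoint keys that don't match the module's
# state dict (and vice versa), so a renamed or missing weight fails silently
# instead of raising; use strict=True to surface mismatches when debugging.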
model.load_state_dict(load_file(model_path, device=str(device)), strict=False)
model.eval()

# -----------------------------
# Special Tokens for Chat Format
# -----------------------------
EOS_TOKEN = "<|im_end|>"
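# Assumes the tokenizer encodes "<|im_end|>" as a single special-token id; if
# it split the marker into several ids, .ids[0] would keep only a fragment.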
EOS_TOKEN_ID = tokenizer.encode(EOS_TOKEN).ids[0]
print(f"EOS token ID: {EOS_TOKEN_ID}")

# -----------------------------
# Generation Function
# -----------------------------
@torch.no_grad()
def generate_response(prompt, max_tokens=200, temperature=0.7, top_p=0.9):
    """
    Generates a chat-style response using the Lumen-Instruct model.
    """
    # Format the input as a structured conversation
    formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

    # Tokenize input
    input_ids = torch.tensor([tokenizer.encode(formatted_prompt).ids], dtype=torch.long, device=device)

    # Generate response with sampling
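    # Note: top_k is fixed at 50 here, while temperature and top_p come from
    # the UI sliders; generate() is the custom sampler from ModelArchitecture.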
    output = generate(
        model,
        input_ids,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_k=50,
        top_p=top_p,
        do_sample=True,
        eos_token_id=EOS_TOKEN_ID,
    )

    # Decode full output text
    full_text = tokenizer.decode(output[0].tolist())

    # Extract only the assistant's part
    if "<|im_start|>assistant" in full_text:
        response = full_text.split("<|im_start|>assistant")[-1]
        if "<|im_end|>" in response:
            response = response.split("<|im_end|>")[0]
        return response.strip()

    return full_text.strip()
# -----------------------------
# Gradio Interface
# -----------------------------
demo = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(label="User Prompt", placeholder="Ask Lumen anything...", lines=3),
        gr.Slider(10, 500, value=200, label="Max Tokens"),
        gr.Slider(0.1, 2.0, value=0.7, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.9, label="Top-p"),
    ],
    outputs=gr.Textbox(label="Lumen's Response", lines=10),
    title="Lumen Instruct Model",
    description="Chat with Lumen — a fine-tuned instruction-following language model created by Hariom Jangra.",
)

# -----------------------------
# Launch
# -----------------------------
if __name__ == "__main__":
    demo.launch()
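
Once the Space is up, the same interface can be queried programmatically. Below is a minimal sketch using gradio_client; the Space id (assumed to match REPO_ID) and the api_name="/predict" endpoint (gr.Interface's default) are assumptions, not something this commit confirms.

from gradio_client import Client

# Assumed Space id; adjust if the demo is hosted under another name.
client = Client("VirtualInsight/Lumen-Instruct")

# Positional args mirror the gr.Interface inputs above:
# prompt, max tokens, temperature, top-p.
reply = client.predict(
    "Explain instruction tuning in one paragraph.",
    200,   # Max Tokens
    0.7,   # Temperature
    0.9,   # Top-p
    api_name="/predict",
)
print(reply)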