nova committed on
Commit
fa15145
·
verified ·
1 Parent(s): af61ec2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# app.py — Gradio chat UI streaming responses from a local Phi-3.5 model.
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

# Model Configuration
MODEL_ID = "microsoft/Phi-3.5-mini-instruct"

# Check GPU — fall back to CPU (with float32 weights, see load below) when
# no CUDA device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Loading {MODEL_ID} on {device}...")
# Load tokenizer + model once at startup; everything below depends on them.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        # fp16 on GPU for memory/speed; fp32 on CPU where fp16 is unsupported/slow.
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        trust_remote_code=True,
        device_map="auto",
    )
except Exception as e:
    # Fix: the original handler only printed and fell through, so a failed
    # load later crashed with NameError on the undefined `model`/`tokenizer`.
    # Without a model the app cannot serve anything — exit with a clear error.
    print(f"❌ Error loading model: {e}")
    raise SystemExit(1) from e
def format_prompt(message, history, system_prompt):
    """Assemble a Phi-3 chat prompt from the system prompt, past turns, and
    the new user message.

    Phi-3 wire format:
        <|system|>\n...<|end|>\n<|user|>\n...<|end|>\n<|assistant|>\n

    Args:
        message: the latest user message.
        history: iterable of (user_message, assistant_reply) pairs.
        system_prompt: instruction text placed in the system slot.

    Returns:
        The full prompt string, ending with an open assistant turn so the
        model continues from there.
    """
    segments = [f"<|system|>\n{system_prompt}<|end|>\n"]
    for past_user, past_assistant in history:
        segments.append(
            f"<|user|>\n{past_user}<|end|>\n<|assistant|>\n{past_assistant}<|end|>\n"
        )
    # Open assistant turn: generation picks up right after this marker.
    segments.append(f"<|user|>\n{message}<|end|>\n<|assistant|>\n")
    return "".join(segments)
def chat(message, history):
    """Stream an assistant reply for gr.ChatInterface.

    Args:
        message: latest user message (str).
        history: list of (user, assistant) tuple pairs supplied by Gradio.

    Yields:
        The cumulative response text so far, so the UI renders a live stream.
    """
    # Default System Prompt for Lumin
    SYSTEM_PROMPT = "You are Lumin Flash, a helpful and efficient AI assistant."

    # 1. Format Input
    prompt_text = format_prompt(message, history, SYSTEM_PROMPT)
    inputs = tokenizer(prompt_text, return_tensors="pt").to(device)

    # 2. Streamer — skip_prompt drops the echoed input; special tokens
    # (e.g. <|end|>) are stripped from the visible text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # 3. Generate on a background thread so this generator can consume the
    # stream concurrently.
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=1024,
        temperature=0.7,
        do_sample=True,
        # Fix: set an explicit pad token — Phi-3 tokenizers may not define
        # one, which otherwise causes a per-call warning and ill-defined
        # padding behavior in generate().
        pad_token_id=tokenizer.eos_token_id,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # 4. Yield cumulative output as chunks arrive from the streamer.
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    # Fix: reap the worker thread once the stream is exhausted so each
    # request doesn't leave a dangling (albeit finished) thread behind.
    thread.join()
# Gradio Interface — chat UI wired to the streaming `chat` generator above.
demo = gr.ChatInterface(
    fn=chat,  # generator function: each yield replaces the pending reply
    chatbot=gr.Chatbot(height=600),
    textbox=gr.Textbox(placeholder="Ask Lumin Flash...", container=False, scale=7),
    title="Lumin Flash (Phi-3.5)",
    theme="soft",
    # NOTE(review): retry_btn/undo_btn/clear_btn were removed from
    # gr.ChatInterface in Gradio 5.x — confirm this Space pins gradio<5,
    # otherwise construction raises TypeError.
    retry_btn=None,
    undo_btn=None,
    clear_btn="Clear",
)

if __name__ == "__main__":
    # queue() is required for streaming generators; 0.0.0.0:7860 is the
    # standard Hugging Face Spaces host/port binding.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)