Wenye He committed on
Commit
5276429
·
verified ·
1 Parent(s): 8ff4f53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -31
app.py CHANGED
@@ -1,35 +1,97 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
-
5
- # Use Phi model (ensure to pass trust_remote_code if required)
6
- model_name = "microsoft/Phi-3-mini-4k-instruct"
7
- model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
8
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
9
-
10
- def generate_response_phi(user_input, chat_history):
11
- if chat_history is None:
12
- chat_history = []
13
- # Append user message to the conversation as a dict (the Phi template expects this format)
14
- chat_history.append({"role": "user", "content": user_input})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Use the tokenizer's chat template to prepare inputs
17
- inputs = tokenizer.apply_chat_template(
18
- chat_history, add_generation_prompt=True, return_tensors="pt"
19
- )
20
- # Generate response
21
- output_ids = model.generate(**inputs, max_new_tokens=100)
22
- generated_text = tokenizer.batch_decode(output_ids)[0]
23
- # Extract assistant reply (assuming the template adds "<|assistant|>" marker)
24
- answer = generated_text.split("<|assistant|>")[-1].strip()
25
- chat_history.append({"role": "assistant", "content": answer})
26
- return "", chat_history
27
-
28
- with gr.Blocks() as phi_demo:
29
- gr.Markdown("# Phi Chatbot")
30
- chatbot = gr.Chatbot()
31
- state = gr.State([])
32
- txt = gr.Textbox(placeholder="Enter your message")
33
- txt.submit(generate_response_phi, [txt, state], [txt, chatbot])
34
 
35
- phi_demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
+
5
# Model configurations: Hugging Face repo id + single-turn prompt template
# for each supported model.
MODEL_CONFIG = {
    "llama": {
        "model_name": "meta-llama/Llama-2-7b-chat-hf",
        "template": "[INST] {message} [/INST]"
    },
    "phi": {
        "model_name": "microsoft/phi-2",
        "template": "{message}"
    }
}

class ChatModel:
    """Lazily loads one of the configured causal LMs and generates replies.

    The tokenizer, model, and text-generation pipeline are cached and only
    rebuilt when the requested model differs from the currently loaded one.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.current_model = None
        # Cached text-generation pipeline, built once per model in load_model
        # (the original rebuilt it on every generate() call).
        self.pipe = None

    def load_model(self, model_name):
        """Load tokenizer/model for `model_name` (a MODEL_CONFIG key) if not active.

        No-op when `model_name` is already the current model.
        """
        if model_name != self.current_model:
            config = MODEL_CONFIG[model_name]
            self.tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
            self.model = AutoModelForCausalLM.from_pretrained(
                config["model_name"],
                torch_dtype=torch.float16,
                device_map="auto"
            )
            # Build the pipeline once here instead of per request.
            # NOTE: do NOT pass device_map to pipeline() — the model was already
            # dispatched by from_pretrained(device_map="auto"); passing it again
            # causes accelerate warnings/errors.
            self.pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer
            )
            self.current_model = model_name

    def format_message(self, message, model_name):
        """Wrap `message` in the prompt template of the chosen model."""
        return MODEL_CONFIG[model_name]["template"].format(message=message)

    def generate(self, message, model_name, history):
        """Generate a reply for `message` using `model_name`.

        `history` is accepted for interface compatibility but is not yet folded
        into the prompt (single-turn prompting) — TODO: build a multi-turn prompt.
        Returns the model's reply with the echoed prompt stripped.
        """
        self.load_model(model_name)
        formatted_message = self.format_message(message, model_name)

        # max_new_tokens (not max_length) so a long prompt cannot silently
        # consume the whole generation budget.
        response = self.pipe(
            formatted_message,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            pad_token_id=self.tokenizer.eos_token_id
        )

        generated = response[0]['generated_text']
        # The text-generation pipeline echoes the prompt; strip it as a prefix
        # rather than with str.replace(), which would also delete any later
        # repetition of the prompt text inside the reply.
        if generated.startswith(formatted_message):
            generated = generated[len(formatted_message):]
        return generated.strip()
61
+
62
# Single shared model handler; models are loaded lazily on first use.
model_handler = ChatModel()

def chat(message, history, model_choice):
    """Gradio callback: answer `message` with `model_choice` and append the turn.

    Returns the FULL updated history as a list of (user, assistant) pairs.
    The original returned only the newest pair, which wiped all earlier turns
    from the Chatbot component on every submit.
    """
    response = model_handler.generate(message, model_choice, history)
    return (history or []) + [(message, response)]
68
+
69
# Build the UI: model selector, chat display, input box, and send/clear controls.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 Local LLM Chatbot\nSelect a model and start chatting!")

    with gr.Row():
        model_choice = gr.Dropdown(
            choices=["llama", "phi"],
            label="Select Model",
            value="phi"
        )

    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")

    with gr.Row():
        submit_btn = gr.Button("Send")
        clear_btn = gr.ClearButton([msg, chatbot])

    # Enter in the textbox and clicking Send both route through the same
    # handler with identical inputs/outputs, so wire them in one pass.
    for trigger in (msg.submit, submit_btn.click):
        trigger(
            fn=chat,
            inputs=[msg, chatbot, model_choice],
            outputs=[chatbot]
        )

demo.launch()