Wenye He committed
Commit dd93054 · verified · 1 Parent(s): 7a75e11

Update app.py

Files changed (1)
  1. app.py +61 -58
app.py CHANGED
@@ -1,97 +1,100 @@
Previous version (lines prefixed "-" were removed):

 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

-# Model configurations
 MODEL_CONFIG = {
-    "llama": {
-        "model_name": "meta-llama/Llama-2-7b-chat-hf",
-        "template": "[INST] {message} [/INST]"
     },
-    "phi": {
-        "model_name": "microsoft/phi-2",
-        "template": "{message}"
     }
 }

 class ChatModel:
     def __init__(self):
-        self.model = None
-        self.tokenizer = None
-        self.current_model = None
-
     def load_model(self, model_name):
-        if model_name != self.current_model:
             config = MODEL_CONFIG[model_name]
-            self.tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
-            self.model = AutoModelForCausalLM.from_pretrained(
                 config["model_name"],
-                torch_dtype=torch.float16,
-                device_map="auto"
             )
-            self.current_model = model_name
-
-    def format_message(self, message, model_name):
-        return MODEL_CONFIG[model_name]["template"].format(message=message)

     def generate(self, message, model_name, history):
         self.load_model(model_name)
-        formatted_message = self.format_message(message, model_name)

-        # Create pipeline for text generation
         pipe = pipeline(
             "text-generation",
-            model=self.model,
-            tokenizer=self.tokenizer,
-            device_map="auto"
-        )
-
-        # Generate response
-        response = pipe(
-            formatted_message,
-            max_length=200,
-            do_sample=True,
             temperature=0.7,
-            top_k=50,
-            top_p=0.95,
-            pad_token_id=self.tokenizer.eos_token_id
         )

-        return response[0]['generated_text'].replace(formatted_message, "").strip()

-# Initialize model handler
 model_handler = ChatModel()

 def chat(message, history, model_choice):
-    response = model_handler.generate(message, model_choice, history)
-    return [(message, response)]

 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🤖 Local LLM Chatbot\nSelect a model and start chatting!")
-
     with gr.Row():
         model_choice = gr.Dropdown(
-            choices=["llama", "phi"],
             label="Select Model",
-            value="phi"
         )
-
     chatbot = gr.Chatbot(height=400)
-    msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
-
     with gr.Row():
-        submit_btn = gr.Button("Send")
         clear_btn = gr.ClearButton([msg, chatbot])

-    msg.submit(
-        fn=chat,
-        inputs=[msg, chatbot, model_choice],
-        outputs=[chatbot]
-    )
-    submit_btn.click(
-        fn=chat,
-        inputs=[msg, chatbot, model_choice],
-        outputs=[chatbot]
-    )

 demo.launch()
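In the updated version below, Llama-2-7B and phi-2 give way to Phi-3-mini-4k-instruct and Meta-Llama-3-8B-Instruct, both loaded with 4-bit NF4 quantization via BitsAndBytesConfig (at roughly half a byte per weight, the 8B Llama-3 checkpoint needs about 4-5 GB instead of ~16 GB in fp16). Loaded model/tokenizer pairs are now cached in dicts rather than reloaded on every switch, each model gets its own chat template, Phi-3 opts into flash-attention (which requires the flash-attn package to be installed), and generation errors are caught and shown in the chat window.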
 
Updated version (lines prefixed "+" were added):

 import gradio as gr
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
 import torch

 MODEL_CONFIG = {
+    "phi-3": {
+        "model_name": "microsoft/phi-3-mini-4k-instruct",
+        "template": "<|user|>\n{message}<|end|>\n<|assistant|>"
     },
+    "llama3-8b": {
+        "model_name": "meta-llama/Meta-Llama-3-8B-Instruct",
+        "template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
+
+{message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+"""
     }
 }

+# Quantization config for 4-bit loading
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.float16,
+    bnb_4bit_use_double_quant=True
+)
+
 class ChatModel:
     def __init__(self):
+        self.models = {}
+        self.tokenizers = {}
+
     def load_model(self, model_name):
+        if model_name not in self.models:
             config = MODEL_CONFIG[model_name]
+
+            tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
+            tokenizer.pad_token = tokenizer.eos_token
+
+            model = AutoModelForCausalLM.from_pretrained(
                 config["model_name"],
+                quantization_config=bnb_config,
+                device_map="auto",
+                attn_implementation="flash_attention_2" if "phi-3" in model_name else None,
+                torch_dtype=torch.float16
             )
+
+            self.models[model_name] = model
+            self.tokenizers[model_name] = tokenizer

     def generate(self, message, model_name, history):
         self.load_model(model_name)
+        config = MODEL_CONFIG[model_name]

+        # Format prompt
+        prompt = config["template"].format(message=message)
+
+        # Create pipeline
         pipe = pipeline(
             "text-generation",
+            model=self.models[model_name],
+            tokenizer=self.tokenizers[model_name],
+            max_new_tokens=512,
             temperature=0.7,
+            top_p=0.9,
+            repetition_penalty=1.1,
+            do_sample=True,
+            return_full_text=False
         )

+        response = pipe(prompt)[0]['generated_text']
+        return response.strip()

 model_handler = ChatModel()

 def chat(message, history, model_choice):
+    try:
+        response = model_handler.generate(message, model_choice, history)
+        return [(message, response)]
+    except Exception as e:
+        return [(message, f"Error: {str(e)}")]

 with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🚀 Phi-3 vs Llama-3 Chatbot")
     with gr.Row():
         model_choice = gr.Dropdown(
+            choices=["phi-3", "llama3-8b"],
             label="Select Model",
+            value="phi-3"
         )
     chatbot = gr.Chatbot(height=400)
+    msg = gr.Textbox(label="Message", placeholder="Type here...")
     with gr.Row():
+        submit_btn = gr.Button("Send", variant="primary")
         clear_btn = gr.ClearButton([msg, chatbot])

+    msg.submit(chat, [msg, chatbot, model_choice], chatbot)
+    submit_btn.click(chat, [msg, chatbot, model_choice], chatbot)

 demo.launch()
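One thing the update does not change: generate() accepts history but never reads it, and chat() returns a fresh single-pair list, so every reply replaces the whole Chatbot transcript. A minimal sketch of a handler that appends instead, assuming the same tuple-style Chatbot value (chat_append is hypothetical, not part of this commit):

# Hypothetical variant of chat(): keep prior turns visible by appending
# the new (user, bot) pair to the history the Chatbot component passes in.
def chat_append(message, history, model_choice):
    try:
        response = model_handler.generate(message, model_choice, history)
    except Exception as e:
        response = f"Error: {e}"
    history = history or []  # the Chatbot value may start out empty
    return history + [(message, response)]

Even with this change, the model itself would still only see the latest message unless generate() also folded history into the prompt.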
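The hand-written template strings must track each model's chat format exactly. Since both instruct tokenizers ship their own chat template, tokenizer.apply_chat_template could build the prompt instead; a sketch under that assumption, with build_prompt as a hypothetical helper:

# Hypothetical helper: derive the prompt from the tokenizer's bundled
# chat template instead of the hard-coded strings in MODEL_CONFIG.
def build_prompt(tokenizer, message):
    messages = [{"role": "user", "content": message}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,             # return the formatted string
        add_generation_prompt=True  # end with the assistant header
    )

generate() would then call build_prompt(self.tokenizers[model_name], message) in place of the .format(message=message) line, which also removes the risk of a template drifting out of sync with a model update.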