Wenye He committed on
Commit
97128d6
·
verified ·
1 Parent(s): cfb24bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -52
app.py CHANGED
@@ -8,7 +8,7 @@ MODEL_CONFIG = {
8
  "template": "<|user|>\n{message}<|end|>\n<|assistant|>"
9
  },
10
  "llama3-8b": {
11
- "model_name": "meta-llama/Meta-Llama-3-8B-Instruct",
12
  "template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
13
 
14
  {message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
@@ -17,7 +17,6 @@ MODEL_CONFIG = {
17
  }
18
  }
19
 
20
- # Quantization config (4-bit)
21
  bnb_config = BitsAndBytesConfig(
22
  load_in_4bit=True,
23
  bnb_4bit_quant_type="nf4",
@@ -42,59 +41,12 @@ class ChatModel:
42
  quantization_config=bnb_config,
43
  device_map="auto",
44
  torch_dtype=torch.float16,
45
- low_cpu_mem_usage=True
46
  )
47
 
48
  self.models[model_name] = model
49
  self.tokenizers[model_name] = tokenizer
50
 
51
- def generate(self, message, model_name, history):
52
- self.load_model(model_name)
53
- config = MODEL_CONFIG[model_name]
54
-
55
- # Format prompt
56
- prompt = config["template"].format(message=message)
57
-
58
- # Create pipeline
59
- pipe = pipeline(
60
- "text-generation",
61
- model=self.models[model_name],
62
- tokenizer=self.tokenizers[model_name],
63
- max_new_tokens=384,
64
- temperature=0.7,
65
- top_p=0.9,
66
- repetition_penalty=1.1,
67
- do_sample=True,
68
- return_full_text=False
69
- )
70
-
71
- response = pipe(prompt)[0]['generated_text']
72
- return response.strip()
73
-
74
- model_handler = ChatModel()
75
-
76
- def chat(message, history, model_choice):
77
- try:
78
- response = model_handler.generate(message, model_choice, history)
79
- return [(message, response)]
80
- except Exception as e:
81
- return [(message, f"Error: {str(e)}")]
82
-
83
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
84
- gr.Markdown("# 🚀 Phi-3 vs Llama-3 Chatbot")
85
- with gr.Row():
86
- model_choice = gr.Dropdown(
87
- choices=["phi-3", "llama3-8b"],
88
- label="Select Model",
89
- value="phi-3"
90
- )
91
- chatbot = gr.Chatbot(height=400)
92
- msg = gr.Textbox(label="Message", placeholder="Type here...")
93
- with gr.Row():
94
- submit_btn = gr.Button("Send", variant="primary")
95
- clear_btn = gr.ClearButton([msg, chatbot])
96
-
97
- msg.submit(chat, [msg, chatbot, model_choice], chatbot)
98
- submit_btn.click(chat, [msg, chatbot, model_choice], chatbot)
99
 
100
- demo.launch()
 
8
  "template": "<|user|>\n{message}<|end|>\n<|assistant|>"
9
  },
10
  "llama3-8b": {
11
+ "model_name": "NousResearch/Meta-Llama-3-8B-Instruct",
12
  "template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
13
 
14
  {message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
17
  }
18
  }
19
 
 
20
  bnb_config = BitsAndBytesConfig(
21
  load_in_4bit=True,
22
  bnb_4bit_quant_type="nf4",
 
41
  quantization_config=bnb_config,
42
  device_map="auto",
43
  torch_dtype=torch.float16,
44
+ trust_remote_code=True
45
  )
46
 
47
  self.models[model_name] = model
48
  self.tokenizers[model_name] = tokenizer
49
 
50
+ # ... (keep the rest of the code the same as previous version)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ # ... (remaining code identical to previous implementation)