Wenye He committed
Commit dd8d3db · verified · 1 Parent(s): 3a4c40c

Update app.py

Files changed (1)
  1. app.py +18 -44
app.py CHANGED
@@ -1,7 +1,6 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
 import torch
-import time
 
 MODEL_CONFIG = {
     "phi-3": {
@@ -49,64 +48,39 @@ class ChatModel:
         self.tokenizers[model_name] = tokenizer
 
     def generate(self, message, model_name, history):
-        start_time = time.time()
         self.load_model(model_name)
         config = MODEL_CONFIG[model_name]
 
         # Format prompt
         prompt = config["template"].format(message=message)
 
-        # Tokenize input with proper max_length handling
-        inputs = self.tokenizers[model_name](
-            prompt,
-            return_tensors="pt",
-            max_length=2048,
-            truncation=True
-        ).to("cuda")
-
-        # Generation parameters
-        generation_kwargs = {
-            "inputs": inputs.input_ids,
-            "max_new_tokens": 384,
-            "temperature": 0.7,
-            "top_p": 0.9,
-            "do_sample": True,
-            "pad_token_id": self.tokenizers[model_name].eos_token_id
-        }
-
-        # Phi-3 specific workaround
-        if "phi-3" in model_name:
-            generation_kwargs["attention_mask"] = inputs.attention_mask
-            generation_kwargs.pop("inputs")
-            generation_kwargs["input_ids"] = inputs.input_ids
-
-        outputs = self.models[model_name].generate(**generation_kwargs)
-
-        # Decode response
-        response = self.tokenizers[model_name].decode(
-            outputs[0][inputs.input_ids.shape[-1]:],
-            skip_special_tokens=True
-        ).strip()
-
-        # Calculate metrics
-        elapsed_time = time.time() - start_time
-        tokens = outputs[0].shape[-1] - inputs.input_ids.shape[-1]
-        tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
+        # Create pipeline
+        pipe = pipeline(
+            "text-generation",
+            model=self.models[model_name],
+            tokenizer=self.tokenizers[model_name],
+            max_new_tokens=384,
+            temperature=0.7,
+            top_p=0.9,
+            repetition_penalty=1.1,
+            do_sample=True,
+            return_full_text=False
+        )
 
-        return response, elapsed_time, tokens_per_sec
+        response = pipe(prompt)[0]['generated_text']
+        return response.strip()
 
 model_handler = ChatModel()
 
 def chat(message, history, model_choice):
     try:
-        response, response_time, token_speed = model_handler.generate(message, model_choice, history)
-        formatted_response = f"{response}\n\n⏱️ Response Time: {response_time:.2f}s | 🚀 Speed: {token_speed:.2f} tokens/s"
-        return [(message, formatted_response)]
+        response = model_handler.generate(message, model_choice, history)
+        return [(message, response)]
     except Exception as e:
         return [(message, f"Error: {str(e)}")]
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🚀 LLM Chatbot with Performance Metrics")
+    gr.Markdown("# 🚀 Phi-3 vs Llama-3 Chatbot")
     with gr.Row():
         model_choice = gr.Dropdown(
             choices=["phi-3", "llama3-8b"],
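For reference, the pipeline-based path this commit switches to can be exercised on its own, outside the Gradio app. The sketch below is a minimal reconstruction under stated assumptions: the checkpoint name, the prompt template, and the dtype/device settings are illustrative and are not taken from app.py.

```python
# Minimal sketch of the pipeline-based generation flow adopted in this commit.
# Assumptions (not from app.py): the checkpoint name, the prompt template,
# and the dtype/device settings below.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "microsoft/Phi-3-mini-4k-instruct"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # assumed; needs a GPU with bf16 support
    device_map="auto",           # requires the accelerate package
)

# Same sampling parameters as the new generate() method.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=384,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
    do_sample=True,
    return_full_text=False,  # return only the newly generated text
)

prompt = "<|user|>\nHello, who are you?<|end|>\n<|assistant|>\n"  # assumed Phi-3 chat format
print(pipe(prompt)[0]["generated_text"].strip())
```

Compared with the removed manual path, `return_full_text=False` makes the prompt-stripping slice (`outputs[0][inputs.input_ids.shape[-1]:]`) unnecessary, and the pipeline handles tokenization, device placement, and decoding internally.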