Wenye He committed
Commit f937954 · verified · 1 Parent(s): 4c5f924

Update app.py

Files changed (1)
  1. app.py +41 -33
app.py CHANGED
@@ -1,6 +1,5 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
-from threading import Thread
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 import torch
 import time
 
@@ -19,6 +18,13 @@ MODEL_CONFIG = {
     }
 }
 
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.float16,
+    bnb_4bit_use_double_quant=True
+)
+
 class ChatModel:
     def __init__(self):
         self.models = {}
@@ -28,64 +34,66 @@ class ChatModel:
         if model_name not in self.models:
             config = MODEL_CONFIG[model_name]
 
-            self.tokenizers[model_name] = AutoTokenizer.from_pretrained(config["model_name"])
-            self.tokenizers[model_name].pad_token = self.tokenizers[model_name].eos_token
+            tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
+            tokenizer.pad_token = tokenizer.eos_token
 
-            self.models[model_name] = AutoModelForCausalLM.from_pretrained(
+            model = AutoModelForCausalLM.from_pretrained(
                 config["model_name"],
+                quantization_config=bnb_config,
                 device_map="auto",
                 torch_dtype=torch.float16,
-                attn_implementation="flash_attention_2" if "phi-3" in model_name else "eager",
                 trust_remote_code=True
             )
+
+            self.models[model_name] = model
+            self.tokenizers[model_name] = tokenizer
 
-    def stream_response(self, message, model_name):
+    def generate(self, message, model_name, history):
+        start_time = time.time()
         self.load_model(model_name)
         config = MODEL_CONFIG[model_name]
-        tokenizer = self.tokenizers[model_name]
-        model = self.models[model_name]
 
+        # Format prompt
         prompt = config["template"].format(message=message)
-        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=60)
-        generation_kwargs = dict(
+        # Tokenize input
+        inputs = self.tokenizers[model_name](prompt, return_tensors="pt").to("cuda")
+
+        # Generate response
+        outputs = self.models[model_name].generate(
             **inputs,
-            streamer=streamer,
-            max_new_tokens=512,
+            max_new_tokens=384,
            temperature=0.7,
             top_p=0.9,
-            repetition_penalty=1.1,
             do_sample=True,
-            pad_token_id=tokenizer.eos_token_id
+            pad_token_id=self.tokenizers[model_name].eos_token_id
         )
 
-        thread = Thread(target=model.generate, kwargs=generation_kwargs)
-        thread.start()
+        # Decode response
+        response = self.tokenizers[model_name].decode(
+            outputs[0][inputs.input_ids.shape[-1]:],
+            skip_special_tokens=True
+        ).strip()
+
+        # Calculate metrics
+        elapsed_time = time.time() - start_time
+        tokens = outputs[0].shape[0] - inputs.input_ids.shape[-1]
+        tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
 
-        return streamer, tokenizer, time.time()
+        return response, elapsed_time, tokens_per_sec
 
 model_handler = ChatModel()
 
 def chat(message, history, model_choice):
     try:
-        streamer, tokenizer, start_time = model_handler.stream_response(message, model_choice)
-        buffer = ""
-
-        for new_text in streamer:
-            buffer += new_text
-            yield [(message, buffer)]
-
-        elapsed_time = time.time() - start_time
-        tokens = len(tokenizer.encode(buffer))
-        token_speed = tokens / elapsed_time if elapsed_time > 0 else 0
-        final_response = f"{buffer}\n\n⏱️ {elapsed_time:.2f}s | 🚀 {token_speed:.2f} tokens/s"
-        yield [(message, final_response)]
+        response, response_time, token_speed = model_handler.generate(message, model_choice, history)
+        formatted_response = f"{response}\n\n⏱️ Response Time: {response_time:.2f}s | 🚀 Speed: {token_speed:.2f} tokens/s"
+        return [(message, formatted_response)]
     except Exception as e:
-        yield [(message, f"Error: {str(e)}")]
+        return [(message, f"Error: {str(e)}")]
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🚀 Streaming LLM Chatbot (Fixed)")
+    gr.Markdown("# 🚀 LLM Chatbot with Performance Metrics")
     with gr.Row():
         model_choice = gr.Dropdown(
             choices=["phi-3", "llama3-8b"],
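
For readers skimming the diff, the core pattern the new app.py relies on is loading a causal LM with bitsandbytes 4-bit NF4 quantization and timing a blocking generate() call. Below is a minimal, self-contained sketch of that pattern; the model id and prompt template are illustrative placeholders, not the repo's actual MODEL_CONFIG entries.

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Same quantization settings as the commit: 4-bit NF4 weights,
# fp16 compute, double quantization of the quantization constants.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

model_id = "microsoft/Phi-3-mini-4k-instruct"  # placeholder id, not necessarily MODEL_CONFIG's value
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)

prompt = "<|user|>\nHello!<|end|>\n<|assistant|>\n"  # placeholder template
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Blocking generation, timed end to end.
start = time.time()
outputs = model.generate(
    **inputs,
    max_new_tokens=384,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id
)
elapsed = time.time() - start

# Decode only the newly generated tokens and report throughput.
new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True).strip())
print(f"{new_tokens.shape[0] / elapsed:.2f} tokens/s")

Unlike the removed TextIteratorStreamer/Thread path, this call blocks until generation finishes, which is what lets the new chat() compute an exact new-token count from outputs before returning a single formatted reply.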