Wenye He committed · verified
Commit 4479f26 · 1 Parent(s): 1cb71a2

Update app.py

Files changed (1): app.py (+35, -40)

app.py CHANGED
@@ -1,7 +1,8 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+from threading import Thread
 import torch
-import time # Added for timing
+import time
 
 MODEL_CONFIG = {
     "phi-3": {
@@ -18,13 +19,6 @@ MODEL_CONFIG = {
     }
 }
 
-bnb_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_quant_type="nf4",
-    bnb_4bit_compute_dtype=torch.float16,
-    bnb_4bit_use_double_quant=True
-)
-
 class ChatModel:
     def __init__(self):
         self.models = {}
@@ -34,62 +28,63 @@ class ChatModel:
         if model_name not in self.models:
             config = MODEL_CONFIG[model_name]
 
-            tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
-            tokenizer.pad_token = tokenizer.eos_token
+            self.tokenizers[model_name] = AutoTokenizer.from_pretrained(config["model_name"])
+            self.tokenizers[model_name].pad_token = self.tokenizers[model_name].eos_token
 
-            model = AutoModelForCausalLM.from_pretrained(
+            self.models[model_name] = AutoModelForCausalLM.from_pretrained(
                 config["model_name"],
-                quantization_config=bnb_config,
                 device_map="auto",
                 torch_dtype=torch.float16,
                 trust_remote_code=True
             )
-
-            self.models[model_name] = model
-            self.tokenizers[model_name] = tokenizer
 
-    def generate(self, message, model_name, history):
-        start_time = time.time() # Start timing
+    def stream_response(self, message, model_name):
         self.load_model(model_name)
         config = MODEL_CONFIG[model_name]
+        tokenizer = self.tokenizers[model_name]
+        model = self.models[model_name]
 
-        # Format prompt
         prompt = config["template"].format(message=message)
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
-        # Create pipeline
-        pipe = pipeline(
-            "text-generation",
-            model=self.models[model_name],
-            tokenizer=self.tokenizers[model_name],
-            max_new_tokens=384,
+        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
+        generation_kwargs = dict(
+            input_ids=inputs.input_ids,
+            streamer=streamer,
+            max_new_tokens=512,
             temperature=0.7,
             top_p=0.9,
             repetition_penalty=1.1,
             do_sample=True,
-            return_full_text=False
+            pad_token_id=tokenizer.eos_token_id
        )
 
-        response = pipe(prompt)[0]['generated_text'].strip()
-
-        # Calculate metrics
-        elapsed_time = time.time() - start_time
-        tokens = len(self.tokenizers[model_name].encode(response))
-        tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
 
-        return response, elapsed_time, tokens_per_sec
+        return streamer, tokenizer, time.time()
 
 model_handler = ChatModel()
 
 def chat(message, history, model_choice):
-    try:
-        response, response_time, token_speed = model_handler.generate(message, model_choice, history)
-        formatted_response = f"{response}\n\n⏱️ Response Time: {response_time:.2f}s | 🚀 Speed: {token_speed:.2f} tokens/s"
-        return [(message, formatted_response)]
-    except Exception as e:
-        return [(message, f"Error: {str(e)}")]
+    # Initialize streaming
+    streamer, tokenizer, start_time = model_handler.stream_response(message, model_choice)
+    buffer = ""
+
+    # Stream tokens
+    for new_text in streamer:
+        buffer += new_text
+        yield [(message, buffer)]
+
+    # Add performance metrics
+    elapsed_time = time.time() - start_time
+    tokens = len(tokenizer.encode(buffer))
+    token_speed = tokens / elapsed_time if elapsed_time > 0 else 0
+    final_response = f"{buffer}\n\n⏱️ {elapsed_time:.2f}s | 🚀 {token_speed:.2f} tokens/s"
+    yield [(message, final_response)]
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🚀 LLM Chatbot with Performance Metrics")
+    gr.Markdown("# 🚀 Streaming LLM Chatbot")
     with gr.Row():
         model_choice = gr.Dropdown(
             choices=["phi-3", "llama3-8b"],