Wenye He committed (verified)
Commit 4c5f924 · 1 Parent(s): 4479f26

Update app.py

Files changed (1): app.py (+19 -18)
app.py CHANGED
@@ -35,6 +35,7 @@ class ChatModel:
             config["model_name"],
             device_map="auto",
             torch_dtype=torch.float16,
+            attn_implementation="flash_attention_2" if "phi-3" in model_name else "eager",
             trust_remote_code=True
         )
 
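Note on the new `attn_implementation` argument: FlashAttention-2 kernels require the separate flash-attn package, a recent GPU, and fp16/bf16 weights, so the conditional opts phi-3 into the fast path and leaves other models on the default "eager" attention. A minimal sketch of the loader this hunk patches; the `load_model` wrapper is hypothetical, only the argument list comes from the diff:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    def load_model(config, model_name):
        # "flash_attention_2" needs the flash-attn package; "eager" is the
        # plain PyTorch attention fallback that works everywhere.
        attn = "flash_attention_2" if "phi-3" in model_name else "eager"
        model = AutoModelForCausalLM.from_pretrained(
            config["model_name"],
            device_map="auto",
            torch_dtype=torch.float16,  # flash-attn kernels only support fp16/bf16
            attn_implementation=attn,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            config["model_name"], trust_remote_code=True
        )
        return model, tokenizer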
@@ -47,9 +48,9 @@ class ChatModel:
         prompt = config["template"].format(message=message)
         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
+        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=60)
         generation_kwargs = dict(
-            inputs.input_ids,
+            **inputs,
             streamer=streamer,
             max_new_tokens=512,
             temperature=0.7,
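Two fixes land in this hunk. `dict(inputs.input_ids, ...)` fails at runtime, because a tensor is not a valid first argument to `dict()`, and it would drop the attention mask even if it worked; `**inputs` spreads the full tokenizer output (`input_ids` plus `attention_mask`) into keyword arguments for `generate()`. The added `timeout=60` makes the consuming loop raise after a minute of silence instead of hanging forever when the generation thread dies. A hedged sketch of what `stream_response` plausibly looks like around these lines; the `self.models` lookup and `do_sample=True` are assumptions not shown in the diff:

    import time
    from threading import Thread
    from transformers import TextIteratorStreamer

    class ChatModel:
        def stream_response(self, message, model_choice):
            # Hypothetical lookup; the diff does not show how models are stored.
            model, tokenizer, config = self.models[model_choice]

            prompt = config["template"].format(message=message)
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

            # timeout=60: iterating the streamer raises queue.Empty after 60s
            # of silence instead of blocking if generation crashes mid-stream.
            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=60)
            generation_kwargs = dict(
                **inputs,        # input_ids AND attention_mask
                streamer=streamer,
                max_new_tokens=512,
                temperature=0.7,
                do_sample=True,  # assumed; temperature has no effect without sampling
            )
            # generate() blocks until finished, so it runs on a background
            # thread while the caller consumes tokens from the streamer.
            Thread(target=model.generate, kwargs=generation_kwargs, daemon=True).start()
            return streamer, tokenizer, time.time()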
@@ -67,24 +68,24 @@ class ChatModel:
 model_handler = ChatModel()
 
 def chat(message, history, model_choice):
-    # Initialize streaming
-    streamer, tokenizer, start_time = model_handler.stream_response(message, model_choice)
-    buffer = ""
-
-    # Stream tokens
-    for new_text in streamer:
-        buffer += new_text
-        yield [(message, buffer)]
-
-    # Add performance metrics
-    elapsed_time = time.time() - start_time
-    tokens = len(tokenizer.encode(buffer))
-    token_speed = tokens / elapsed_time if elapsed_time > 0 else 0
-    final_response = f"{buffer}\n\n⏱️ {elapsed_time:.2f}s | 🚀 {token_speed:.2f} tokens/s"
-    yield [(message, final_response)]
+    try:
+        streamer, tokenizer, start_time = model_handler.stream_response(message, model_choice)
+        buffer = ""
+
+        for new_text in streamer:
+            buffer += new_text
+            yield [(message, buffer)]
+
+        elapsed_time = time.time() - start_time
+        tokens = len(tokenizer.encode(buffer))
+        token_speed = tokens / elapsed_time if elapsed_time > 0 else 0
+        final_response = f"{buffer}\n\n⏱️ {elapsed_time:.2f}s | 🚀 {token_speed:.2f} tokens/s"
+        yield [(message, final_response)]
+    except Exception as e:
+        yield [(message, f"Error: {str(e)}")]
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🚀 Streaming LLM Chatbot")
+    gr.Markdown("# 🚀 Streaming LLM Chatbot (Fixed)")
     with gr.Row():
         model_choice = gr.Dropdown(
             choices=["phi-3", "llama3-8b"],
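Because `chat` is a generator, each `yield` hands Gradio a fresh `[(message, buffer)]` pair to render, and the new `try/except` turns streamer timeouts or generation failures into a visible chat message instead of a silently stuck UI. The diff cuts off before the event wiring, so the following is a speculative sketch of how the pieces likely connect; the `chatbot` and `msg` components and the `queue()` call are assumptions:

    import gradio as gr

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🚀 Streaming LLM Chatbot (Fixed)")
        with gr.Row():
            model_choice = gr.Dropdown(
                choices=["phi-3", "llama3-8b"],
                value="phi-3",
                label="Model",
            )
        chatbot = gr.Chatbot()             # renders the (user, bot) tuples chat() yields
        msg = gr.Textbox(label="Message")

        # A generator callback streams: each yield replaces the chatbot value,
        # so the growing buffer appears token by token.
        msg.submit(chat, inputs=[msg, chatbot, model_choice], outputs=chatbot)

    # queue() enables streamed (generator) responses over server-sent events.
    demo.queue().launch()

One caveat visible in the shown code: every yield replaces the whole history with a single pair, so earlier turns disappear; appending to `history` before yielding would preserve the conversation.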
 