TobDeBer committed on
Commit
002a426
·
1 Parent(s): ee7dcb0

multiturn chat

Browse files
Files changed (1) hide show
  1. app.py +32 -18
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextIteratorStreamer
4
  import time
5
- import random
6
  from threading import Thread
7
  import sys
8
  import os
@@ -45,19 +44,19 @@ def chat_predict(message, history, max_length, temperature, top_p, repetition_pe
45
  yield "⚠️ Please wait for the model to finish loading..."
46
  return
47
 
48
- if not message.strip():
49
- yield "⚠️ Please enter a prompt."
50
- return
51
-
52
  try:
53
- # Build conversation history
54
  messages = []
55
  if system_prompt:
56
  messages.append({"role": "system", "content": system_prompt})
57
 
58
- for user_msg, assistant_msg in history:
59
- messages.append({"role": "user", "content": user_msg})
60
- messages.append({"role": "assistant", "content": assistant_msg})
 
 
 
 
61
 
62
  messages.append({"role": "user", "content": message})
63
 
@@ -89,13 +88,25 @@ def chat_predict(message, history, max_length, temperature, top_p, repetition_pe
89
  generated_text = ""
90
  start_time = time.time()
91
  token_count = 0
 
 
92
 
93
  for new_text in streamer:
94
  generated_text += new_text
95
  token_count += 1
96
- yield generated_text
 
 
 
 
 
 
 
 
 
 
97
 
98
- # Append stats after generation is complete
99
  elapsed_time = time.time() - start_time
100
  if elapsed_time > 0:
101
  tps = token_count / elapsed_time
@@ -121,7 +132,7 @@ custom_theme = gr.themes.Soft(
121
  )
122
 
123
  # Build the Gradio interface
124
- with gr.Blocks(theme=custom_theme) as demo:
125
  gr.Markdown(
126
  """
127
  # 🤖 Smol LLM Chat
@@ -133,6 +144,8 @@ with gr.Blocks(theme=custom_theme) as demo:
133
  # Chat Interface
134
  chat_interface = gr.ChatInterface(
135
  fn=chat_predict,
 
 
136
  additional_inputs=[
137
  gr.Slider(
138
  minimum=50,
@@ -175,8 +188,9 @@ with gr.Blocks(theme=custom_theme) as demo:
175
  load_status = load_model()
176
  print(f"Startup load status: {load_status}")
177
 
178
- # Launch the application
179
- demo.launch(
180
- share=False,
181
- show_error=True
182
- )
 
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
4
  import time
 
5
  from threading import Thread
6
  import sys
7
  import os
 
44
  yield "⚠️ Please wait for the model to finish loading..."
45
  return
46
 
 
 
 
 
47
  try:
48
+ # Prepare messages for chat template
49
  messages = []
50
  if system_prompt:
51
  messages.append({"role": "system", "content": system_prompt})
52
 
53
+ # history is a list of dicts: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
54
+ for msg in history:
55
+ # Clean up history messages (remove stats if they were appended)
56
+ content = msg["content"]
57
+ if "\n\n---\n*Generated" in content:
58
+ content = content.split("\n\n---\n*Generated")[0]
59
+ messages.append({"role": msg["role"], "content": content})
60
 
61
  messages.append({"role": "user", "content": message})
62
 
 
88
  generated_text = ""
89
  start_time = time.time()
90
  token_count = 0
91
+ last_update_time = start_time
92
+ current_stats = ""
93
 
94
  for new_text in streamer:
95
  generated_text += new_text
96
  token_count += 1
97
+
98
+ # Update stats every 0.2 seconds
99
+ current_time = time.time()
100
+ if current_time - last_update_time > 0.2:
101
+ elapsed = current_time - start_time
102
+ if elapsed > 0:
103
+ tps = token_count / elapsed
104
+ current_stats = f"\n\n---\n*Generating... ({tps:.1f} t/s)*"
105
+ last_update_time = current_time
106
+
107
+ yield generated_text + current_stats
108
 
109
+ # Final stats
110
  elapsed_time = time.time() - start_time
111
  if elapsed_time > 0:
112
  tps = token_count / elapsed_time
 
132
  )
133
 
134
  # Build the Gradio interface
135
+ with gr.Blocks(theme=custom_theme, fill_height=True) as demo:
136
  gr.Markdown(
137
  """
138
  # 🤖 Smol LLM Chat
 
144
  # Chat Interface
145
  chat_interface = gr.ChatInterface(
146
  fn=chat_predict,
147
+ type="messages",
148
+ fill_height=True,
149
  additional_inputs=[
150
  gr.Slider(
151
  minimum=50,
 
188
  load_status = load_model()
189
  print(f"Startup load status: {load_status}")
190
 
191
+ if __name__ == "__main__":
192
+ # Launch the application
193
+ demo.launch(
194
+ share=False,
195
+ show_error=True
196
+ )