Add model switching interruption and support for reasoning model tokens

#4
by treerats88 - opened
Files changed (1) hide show
  1. app.py +61 -6
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
3
  from threading import Thread
4
  import gc
5
  import os
@@ -49,10 +49,17 @@ class ModelManager:
49
  def __init__(self):
50
  self.model = None
51
  self.tokenizer = None
 
 
52
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
53
 
54
  model_manager = ModelManager()
55
 
 
 
 
 
 
56
  def get_system_stats(request: gr.Request = None):
57
  """Returns a dictionary of current system metrics with formatted strings."""
58
  mem = psutil.virtual_memory()
@@ -66,9 +73,13 @@ def get_system_stats(request: gr.Request = None):
66
 
67
  def load_new_model(model_id):
68
  """Loads the model and tokenizer dynamically into the global manager."""
 
 
 
69
  # Clear old model from memory
70
  model_manager.model = None
71
  model_manager.tokenizer = None
 
72
  yield f"Loading {model_id}..."
73
  gc.collect()
74
  if torch.cuda.is_available():
@@ -81,6 +92,7 @@ def load_new_model(model_id):
81
 
82
  model_manager.tokenizer = tokenizer
83
  model_manager.model = model
 
84
 
85
  yield f"Successfully loaded {model_id} on {model_manager.device.upper()}"
86
  except Exception as e:
@@ -91,15 +103,33 @@ def run_inference(user_prompt, max_tokens, temperature, top_k, top_p, rep_penalt
91
  if model_manager.model is None or model_manager.tokenizer is None:
92
  yield "Please load a model first.", "Model not loaded"
93
  return
 
 
 
94
 
95
  tokenizer = model_manager.tokenizer
96
  model = model_manager.model
 
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  # Tokenize input
99
- inputs = tokenizer([user_prompt], return_tensors="pt").to(model_manager.device)
100
 
101
  # Set up the streamer
102
- streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
103
 
104
  # Adjust variables based on the do_sample logic
105
  if not do_sample:
@@ -116,7 +146,8 @@ def run_inference(user_prompt, max_tokens, temperature, top_k, top_p, rep_penalt
116
  repetition_penalty=float(rep_penalty),
117
  no_repeat_ngram_size=int(ngram_size),
118
  do_sample=do_sample,
119
- pad_token_id=tokenizer.eos_token_id # Prevents padding warnings
 
120
  )
121
 
122
  start_time = time.time()
@@ -124,15 +155,39 @@ def run_inference(user_prompt, max_tokens, temperature, top_k, top_p, rep_penalt
124
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
125
  thread.start()
126
 
 
 
 
 
 
 
 
 
127
  # Yield output iteratively for the streaming effect
128
- generated_text = user_prompt
129
  token_count = 0
130
  for new_text in streamer:
 
 
 
 
131
  generated_text += new_text
132
  token_count += 1
133
  duration = time.time() - start_time
134
  tps = token_count / duration if duration > 0 else 0
135
- yield generated_text, f"Speed: {tps:.2f} tokens/sec"
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  def clean_cache():
138
  if os.path.exists(HF_CACHE_DIR):
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
3
  from threading import Thread
4
  import gc
5
  import os
 
49
  def __init__(self):
50
  self.model = None
51
  self.tokenizer = None
52
+ self.model_id = None
53
+ self.stop_generation = False # Added flag to instantly kill generation
54
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
55
 
56
  model_manager = ModelManager()
57
 
58
+ # Custom stopping criterion that halts the generation thread when a new model is being loaded
59
+ class StopOnFlag(StoppingCriteria):
60
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
61
+ return model_manager.stop_generation
62
+
63
  def get_system_stats(request: gr.Request = None):
64
  """Returns a dictionary of current system metrics with formatted strings."""
65
  mem = psutil.virtual_memory()
 
73
 
74
  def load_new_model(model_id):
75
  """Loads the model and tokenizer dynamically into the global manager."""
76
+ # Stop any ongoing generation immediately
77
+ model_manager.stop_generation = True
78
+
79
  # Clear old model from memory
80
  model_manager.model = None
81
  model_manager.tokenizer = None
82
+ model_manager.model_id = None
83
  yield f"Loading {model_id}..."
84
  gc.collect()
85
  if torch.cuda.is_available():
 
92
 
93
  model_manager.tokenizer = tokenizer
94
  model_manager.model = model
95
+ model_manager.model_id = model_id
96
 
97
  yield f"Successfully loaded {model_id} on {model_manager.device.upper()}"
98
  except Exception as e:
 
103
  if model_manager.model is None or model_manager.tokenizer is None:
104
  yield "Please load a model first.", "Model not loaded"
105
  return
106
+
107
+ # Reset the stop flag for the new generation run
108
+ model_manager.stop_generation = False
109
 
110
  tokenizer = model_manager.tokenizer
111
  model = model_manager.model
112
+ model_id = model_manager.model_id
113
 
114
+ is_supra_reasoning = "Supra-50M-Reasoning" in model_id if model_id else False
115
+
116
+ if is_supra_reasoning:
117
+ SYSTEM_PROMPT = "Your role as an assistant involves thoroughly exploring questions through a systematic long thinking process before providing the final precise and accurate solutions."
118
+ prompt_to_encode = (
119
+ f"[SYSTEM]: {SYSTEM_PROMPT}\n\n"
120
+ f"[USER]: {user_prompt}\n\n"
121
+ f"[ASSISTANT]: <|begin_of_thought|>\n"
122
+ )
123
+ skip_special = False
124
+ else:
125
+ prompt_to_encode = user_prompt
126
+ skip_special = True
127
+
128
  # Tokenize input
129
+ inputs = tokenizer([prompt_to_encode], return_tensors="pt").to(model_manager.device)
130
 
131
  # Set up the streamer
132
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=skip_special)
133
 
134
  # Adjust variables based on the do_sample logic
135
  if not do_sample:
 
146
  repetition_penalty=float(rep_penalty),
147
  no_repeat_ngram_size=int(ngram_size),
148
  do_sample=do_sample,
149
+ pad_token_id=tokenizer.eos_token_id, # Prevents padding warnings
150
+ stopping_criteria=StoppingCriteriaList([StopOnFlag()]) # Attach the stopping criteria
151
  )
152
 
153
  start_time = time.time()
 
155
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
156
  thread.start()
157
 
158
+ if is_supra_reasoning:
159
+ # Use plain text formatting rather than markdown symbols inside gr.Textbox
160
+ base_display = f"Prompt: {user_prompt}\n\n----------------------------------------\n\n"
161
+ generated_text = ""
162
+ else:
163
+ base_display = ""
164
+ generated_text = user_prompt
165
+
166
  # Yield output iteratively for the streaming effect
 
167
  token_count = 0
168
  for new_text in streamer:
169
+ # Immediately break out of the UI update loop if a new model is loaded
170
+ if model_manager.stop_generation:
171
+ break
172
+
173
  generated_text += new_text
174
  token_count += 1
175
  duration = time.time() - start_time
176
  tps = token_count / duration if duration > 0 else 0
177
+
178
+ display_text = generated_text
179
+
180
+ if is_supra_reasoning:
181
+ display_text = display_text.replace("<s>", "").replace("</s>", "")
182
+ if not display_text.startswith("🧠 Thinking Process:"):
183
+ display_text = "🧠 Thinking Process:\n" + display_text
184
+
185
+ display_text = display_text.replace("<|begin_of_thought|>", "🧠 Thinking Process:\n")
186
+ display_text = display_text.replace("<|end_of_thought|>", "\n\n")
187
+ display_text = display_text.replace("<|begin_of_solution|>", "✅ Final Answer:\n\n")
188
+ display_text = display_text.replace("<|end_of_solution|>", "")
189
+
190
+ yield base_display + display_text, f"Speed: {tps:.2f} tokens/sec"
191
 
192
  def clean_cache():
193
  if os.path.exists(HF_CACHE_DIR):