EYEDOL committed on
Commit 485e894 · verified · 1 Parent(s): f85ee42

Update app.py

Files changed (1)
  1. app.py +56 -10
app.py CHANGED
@@ -3,10 +3,12 @@
 Refactored Salama Assistant: text-only chatbot (STT and TTS removed)
 Drop this file into your Hugging Face Space (replace existing app.py) or run locally.
 
-This version:
-- Never passes device_map=None (avoids TypeError in accelerate)
-- Detects bitsandbytes availability and only requests 4-bit loading when safe
-- Keeps streaming responses into Gradio chat UI
+Performance-focused tweaks:
+- lower max_new_tokens
+- use greedy decoding (do_sample=False) for speed
+- call generate() under torch.no_grad()
+- set model.config.use_cache = True
+- other minor safe optimizations
 """
 
 import os
@@ -63,6 +65,10 @@ class WeeboAssistant:
             "You are an intelligent assistant. Answer questions briefly and accurately. "
             "Respond only in English. No long answers.\n"
         )
+        # set sensible defaults for generation speed
+        self.MAX_NEW_TOKENS = 256  # lowered from 512 for speed
+        self.DO_SAMPLE = False     # greedy = faster; set True if you need randomness
+        self.NUM_BEAMS = 1         # keep 1 for greedy; increase for beam search (slower)
         self._init_models()
 
     def _init_models(self):
@@ -82,6 +88,15 @@ class WeeboAssistant:
         self.llm_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True)
         print("Loaded tokenizer from ADAPTER_REPO_ID")
 
+        # ensure tokenizer has pad_token_id (some HF models lack it)
+        if getattr(self.llm_tokenizer, "pad_token_id", None) is None:
+            # try to set eos_token_id as pad if pad missing
+            if getattr(self.llm_tokenizer, "eos_token_id", None) is not None:
+                self.llm_tokenizer.pad_token_id = self.llm_tokenizer.eos_token_id
+            else:
+                # fallback to 0 (not ideal but prevents crashes)
+                self.llm_tokenizer.pad_token_id = 0
+
         if torch.cuda.is_available():
             device_map = "auto"
         else:
@@ -106,6 +121,11 @@ class WeeboAssistant:
                 BASE_MODEL_ID,
                 **base_model_kwargs,
             )
+            # make sure use_cache is enabled for faster autoregressive generation
+            try:
+                self.llm_model.config.use_cache = True
+            except Exception:
+                pass
             print("Base model loaded from", BASE_MODEL_ID)
         except Exception as e:
             raise RuntimeError(
@@ -132,6 +152,11 @@ class WeeboAssistant:
                 ADAPTER_REPO_ID,
                 **peft_kwargs,
             )
+            # ensure adapter-wrapped model also has use_cache
+            try:
+                self.llm_model.config.use_cache = True
+            except Exception:
+                pass
             print("PEFT adapter applied from", ADAPTER_REPO_ID)
         except Exception as e:
             raise RuntimeError(
@@ -165,27 +190,48 @@ class WeeboAssistant:
         prompt_lines.append("Assistant: ")
         prompt = "\n".join(prompt_lines)
 
-        inputs = self.llm_tokenizer(prompt, return_tensors="pt")
+        # Tokenize
+        inputs = self.llm_tokenizer(prompt, return_tensors="pt", padding=False)
         try:
            model_device = next(self.llm_model.parameters()).device
        except StopIteration:
            model_device = torch.device("cpu")
        inputs = {k: v.to(model_device) for k, v in inputs.items()}
 
+        # Streamer unchanged (still yields chunks)
         streamer = TextIteratorStreamer(self.llm_tokenizer, skip_prompt=True, skip_special_tokens=True)
 
+        # Prefill some generation kwargs optimized for speed
+        input_len = inputs["input_ids"].shape[1]
+        max_new = self.MAX_NEW_TOKENS
+        max_length = input_len + max_new
+
         generation_kwargs = dict(
             input_ids=inputs["input_ids"],
             attention_mask=inputs.get("attention_mask", None),
-            max_new_tokens=512,
-            do_sample=True,
-            temperature=0.6,
-            top_p=0.9,
+            max_length=max_length,       # prefer max_length = input_len + max_new_tokens
+            max_new_tokens=max_new,      # kept for clarity / compatibility
+            do_sample=self.DO_SAMPLE,    # greedy if False -> faster
+            num_beams=self.NUM_BEAMS,    # beam search >1 slows down; keep 1 for speed
             streamer=streamer,
             eos_token_id=getattr(self.llm_tokenizer, "eos_token_id", None),
+            pad_token_id=getattr(self.llm_tokenizer, "pad_token_id", None),
+            use_cache=True,
+            early_stopping=True,
         )
 
-        gen_thread = threading.Thread(target=self.llm_model.generate, kwargs=generation_kwargs, daemon=True)
+        # Run generate under no_grad for speed / memory
+        def _generate_thread():
+            with torch.no_grad():
+                try:
+                    # call generate on model (PEFT-wrapped)
+                    self.llm_model.generate(**generation_kwargs)
+                except Exception as e:
+                    # if streaming fails, put an error chunk into streamer by raising
+                    # streamer does not provide a direct API to inject text; print to log
+                    print("Generation error:", e)
+
+        gen_thread = threading.Thread(target=_generate_thread, daemon=True)
         gen_thread.start()
 
         return streamer
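The method still hands back the TextIteratorStreamer (last hunk above), so the Gradio layer keeps consuming it as a plain iterator. A minimal consumption sketch follows; the method name get_response_stream, the WeeboAssistant() construction, and the ChatInterface wiring are illustrative assumptions, not part of this commit.

# Sketch only: stream the reply into a Gradio chat UI.
# Assumes WeeboAssistant is importable from app.py and exposes a method
# (called get_response_stream here, hypothetical name) that returns the
# TextIteratorStreamer built in the hunk above.
import gradio as gr
from app import WeeboAssistant

assistant = WeeboAssistant()

def chat_fn(message, history):
    streamer = assistant.get_response_stream(message, history)  # hypothetical name
    partial = ""
    for chunk in streamer:   # TextIteratorStreamer yields decoded text pieces
        partial += chunk
        yield partial        # Gradio re-renders the growing assistant reply

demo = gr.ChatInterface(chat_fn)
demo.queue().launch()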
 
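One caveat on the new generation kwargs: current transformers releases warn when max_length and max_new_tokens are both set and let max_new_tokens take precedence, and early_stopping only affects beam search (num_beams > 1), so it is inert here. An alternative, sketched below under those assumptions, is to attach the speed defaults to the model's generation_config once at load time instead of repeating them on every generate() call.

# Sketch, not part of the commit: set generation defaults once on the model.
# GenerationConfig and its fields are standard transformers API; model and
# tokenizer stand in for self.llm_model and self.llm_tokenizer.
from transformers import GenerationConfig

def apply_generation_defaults(model, tokenizer, max_new_tokens=256):
    model.generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,       # mirrors self.MAX_NEW_TOKENS
        do_sample=False,                     # mirrors self.DO_SAMPLE (greedy)
        num_beams=1,                         # mirrors self.NUM_BEAMS
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        use_cache=True,
    )
    return model

With the defaults attached, generate() only needs the inputs and the streamer, e.g. model.generate(**inputs, streamer=streamer).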