Valtry committed on
Commit
d70c8a7
·
verified ·
1 Parent(s): 034245d

Upload 2 files

Browse files
Files changed (2) hide show
  1. agent.py +2 -2
  2. model.py +30 -15
agent.py CHANGED
@@ -340,7 +340,7 @@ class AgentRouter:
340
  "tools_used": [],
341
  }
342
 
343
- final_text = accumulated.strip()
344
  if not final_text:
345
  final_text = self.model.generate(message=message, memory_context=memory_context, tool_context="")
346
 
@@ -417,7 +417,7 @@ class AgentRouter:
417
  "tools_used": tools_used,
418
  }
419
 
420
- final_text = accumulated.strip()
421
  if not final_text:
422
  final_text = self.model.generate(message=message, memory_context=memory_context, tool_context=tool_context)
423
 
 
340
  "tools_used": [],
341
  }
342
 
343
+ final_text = self.model.clean_response(accumulated)
344
  if not final_text:
345
  final_text = self.model.generate(message=message, memory_context=memory_context, tool_context="")
346
 
 
417
  "tools_used": tools_used,
418
  }
419
 
420
+ final_text = self.model.clean_response(accumulated)
421
  if not final_text:
422
  final_text = self.model.generate(message=message, memory_context=memory_context, tool_context=tool_context)
423
 
model.py CHANGED
@@ -161,6 +161,9 @@ class ModelManager:
161
 
162
  return cleaned
163
 
 
 
 
164
  def generate(self, message: str, memory_context: str = "", tool_context: str = "") -> str:
165
  self.load()
166
  max_new_tokens = self.dynamic_token_budget(message)
@@ -233,30 +236,42 @@ class ModelManager:
233
  worker = threading.Thread(target=model.generate, kwargs=generation_kwargs, daemon=True)
234
  worker.start()
235
 
236
- pieces = []
237
- truncated = False
 
 
 
238
  for piece in streamer:
239
  if not piece:
240
  continue
241
 
242
- if truncated:
243
- continue
244
 
245
- emitted = piece
246
- for marker in ["\nUser:", "\nAssistant:", "\nSystem:"]:
247
- idx = emitted.find(marker)
248
- if idx != -1:
249
- emitted = emitted[:idx]
250
- truncated = True
251
- break
252
 
253
- if emitted:
254
- pieces.append(emitted)
255
- yield emitted
 
 
 
 
 
 
 
256
 
257
  worker.join(timeout=0.1)
258
 
259
- final_text = self._clean_response("".join(pieces))
 
 
 
 
 
 
260
  if final_text:
261
  self._set_cached(key, final_text)
262
 
 
161
 
162
  return cleaned
163
 
164
+ def clean_response(self, text: str) -> str:
165
+ return self._clean_response(text)
166
+
167
  def generate(self, message: str, memory_context: str = "", tool_context: str = "") -> str:
168
  self.load()
169
  max_new_tokens = self.dynamic_token_budget(message)
 
236
  worker = threading.Thread(target=model.generate, kwargs=generation_kwargs, daemon=True)
237
  worker.start()
238
 
239
+ markers = ["\nUser:", "\nAssistant:", "\nSystem:", "User:", "Assistant:", "System:"]
240
+ buffer = ""
241
+ yielded_len = 0
242
+ stop_idx = -1
243
+
244
  for piece in streamer:
245
  if not piece:
246
  continue
247
 
248
+ buffer += piece
 
249
 
250
+ # Find earliest marker in accumulated text (handles marker split across chunks).
251
+ marker_positions = [buffer.find(m) for m in markers if buffer.find(m) != -1]
252
+ if marker_positions:
253
+ stop_idx = min(marker_positions)
 
 
 
254
 
255
+ # Hold a short tail so markers crossing boundaries are still detected safely.
256
+ safe_upto = len(buffer) - 20 if stop_idx == -1 else stop_idx
257
+ if safe_upto > yielded_len:
258
+ out = buffer[yielded_len:safe_upto]
259
+ if out:
260
+ yield out
261
+ yielded_len = safe_upto
262
+
263
+ if stop_idx != -1:
264
+ break
265
 
266
  worker.join(timeout=0.1)
267
 
268
+ if stop_idx == -1 and yielded_len < len(buffer):
269
+ out = buffer[yielded_len:]
270
+ if out:
271
+ yield out
272
+
273
+ truncated_final = buffer[:stop_idx] if stop_idx != -1 else buffer
274
+ final_text = self._clean_response(truncated_final)
275
  if final_text:
276
  self._set_cached(key, final_text)
277