polats committed on
Commit
1df0cfb
·
1 Parent(s): 67f4321

Persona endpoint: stop generation on client disconnect, fail-fast lock, lower token cap (prevents abandoned-gen lock pile-up)

Browse files
Files changed (2) hide show
  1. app.py +15 -10
  2. llm.py +14 -5
app.py CHANGED
@@ -234,6 +234,8 @@ async def persona_generate_stream(request: Request):
234
  seed = body.get("seed", "")
235
  unit_class = body.get("class") or body.get("unitClass") or ""
236
 
 
 
237
  async def gen():
238
  yield _sse("model", {"model": llm.model_id()})
239
  loop = asyncio.get_running_loop()
@@ -244,7 +246,7 @@ async def persona_generate_stream(request: Request):
244
  try:
245
  for chunk in llm.stream_chat(
246
  prompts.PERSONA_SYSTEM, prompts.persona_user_prompt(unit_class, seed),
247
- max_tokens=400, temperature=0.8,
248
  ):
249
  loop.call_soon_threadsafe(q.put_nowait, ("delta", chunk))
250
  except Exception as e: # LlmUnavailable or runtime error
@@ -254,15 +256,18 @@ async def persona_generate_stream(request: Request):
254
  threading.Thread(target=worker, daemon=True).start()
255
 
256
  raw_parts = []
257
- while True:
258
- kind, val = await q.get()
259
- if kind is DONE:
260
- break
261
- if kind == "error":
262
- yield _sse("error", {"error": val})
263
- return
264
- raw_parts.append(val)
265
- yield _sse("delta", {"content": val})
 
 
 
266
 
267
  try:
268
  p = persona_parse.parse_persona_json("".join(raw_parts))
 
234
  seed = body.get("seed", "")
235
  unit_class = body.get("class") or body.get("unitClass") or ""
236
 
237
+ stop = threading.Event() # set when the client disconnects → worker stops, lock frees
238
+
239
  async def gen():
240
  yield _sse("model", {"model": llm.model_id()})
241
  loop = asyncio.get_running_loop()
 
246
  try:
247
  for chunk in llm.stream_chat(
248
  prompts.PERSONA_SYSTEM, prompts.persona_user_prompt(unit_class, seed),
249
+ max_tokens=256, temperature=0.8, should_stop=stop.is_set,
250
  ):
251
  loop.call_soon_threadsafe(q.put_nowait, ("delta", chunk))
252
  except Exception as e: # LlmUnavailable or runtime error
 
256
  threading.Thread(target=worker, daemon=True).start()
257
 
258
  raw_parts = []
259
+ try:
260
+ while True:
261
+ kind, val = await q.get()
262
+ if kind is DONE:
263
+ break
264
+ if kind == "error":
265
+ yield _sse("error", {"error": val})
266
+ return
267
+ raw_parts.append(val)
268
+ yield _sse("delta", {"content": val})
269
+ finally:
270
+ stop.set() # client gone or stream done → release the model
271
 
272
  try:
273
  p = persona_parse.parse_persona_json("".join(raw_parts))
llm.py CHANGED
@@ -108,9 +108,18 @@ def _stream_local(system, user, max_tokens, temperature):
108
  yield delta
109
 
110
 
111
- def stream_chat(system, user, max_tokens=400, temperature=0.8):
112
- """Yield text chunks from the configured backend. Serialized by a module lock.
113
- Raises LlmUnavailable if no backend is available."""
114
- with _lock:
 
 
 
 
115
  gen = _stream_external if BASE_URL else _stream_local
116
- yield from gen(system, user, max_tokens, temperature)
 
 
 
 
 
 
108
  yield delta
109
 
110
 
111
+ def stream_chat(system, user, max_tokens=400, temperature=0.8, should_stop=None):
112
+ """Yield text chunks from the configured backend. Serialized by a module lock so
113
+ one CPU model never decodes two requests at once. `should_stop()` is polled each
114
+ chunk so an abandoned request (client gone) stops promptly and frees the lock.
115
+ Raises LlmUnavailable if no backend is available or the model is busy."""
116
+ if not _lock.acquire(timeout=2):
117
+ raise LlmUnavailable("the model is busy with another request — try again in a moment")
118
+ try:
119
  gen = _stream_external if BASE_URL else _stream_local
120
+ for chunk in gen(system, user, max_tokens, temperature):
121
+ if should_stop and should_stop():
122
+ break
123
+ yield chunk
124
+ finally:
125
+ _lock.release()