Tim Luka Horstmann commited on
Commit ·
2f6b259
1
Parent(s): 7ee889e
Include context
Browse files
app.py
CHANGED
|
@@ -284,6 +284,29 @@ def _format_game_context_for_prompt(game_context: Optional[Union[str, Dict[str,
|
|
| 284 |
parts.append(f"Visited stations so far: {', '.join(uniq)}.")
|
| 285 |
except Exception:
|
| 286 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
return " ".join(parts)
|
| 288 |
except Exception:
|
| 289 |
return ""
|
|
@@ -456,6 +479,7 @@ class QueryRequest(BaseModel):
|
|
| 456 |
history: list
|
| 457 |
game_context: Optional[Union[str, Dict[str, Any]]] = None
|
| 458 |
mode: Optional[str] = None
|
|
|
|
| 459 |
|
| 460 |
class TTSRequest(BaseModel):
|
| 461 |
text: str
|
|
@@ -481,6 +505,13 @@ async def predict(request: Request, query_request: QueryRequest):
|
|
| 481 |
history = query_request.history
|
| 482 |
game_context = query_request.game_context
|
| 483 |
mode = (query_request.mode or '').lower() or None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 484 |
return StreamingResponse(stream_response(query, history, game_context, mode), media_type="text/event-stream")
|
| 485 |
|
| 486 |
@app.post("/api/tts")
|
|
|
|
| 284 |
parts.append(f"Visited stations so far: {', '.join(uniq)}.")
|
| 285 |
except Exception:
|
| 286 |
pass
|
| 287 |
+
# Memory transcript (last messages)
|
| 288 |
+
mem = game_context.get('__memory') if isinstance(game_context, dict) else None
|
| 289 |
+
if isinstance(mem, dict):
|
| 290 |
+
try:
|
| 291 |
+
transcript = mem.get('transcript') or []
|
| 292 |
+
if transcript:
|
| 293 |
+
# Take last few lines and embed compactly
|
| 294 |
+
lines = []
|
| 295 |
+
for m in transcript[-20:]:
|
| 296 |
+
role = (m.get('role') or '').strip()
|
| 297 |
+
src = (m.get('source') or '').strip()
|
| 298 |
+
sta = (m.get('stationName') or '').strip()
|
| 299 |
+
txt = (m.get('content') or '').replace('\n',' ').strip()
|
| 300 |
+
if len(txt) > 2000:
|
| 301 |
+
txt = txt[:2000] + '…'
|
| 302 |
+
label = role if role else 'msg'
|
| 303 |
+
if src or sta:
|
| 304 |
+
label += f"[{src}{'/' + sta if sta else ''}]"
|
| 305 |
+
lines.append(f"- {label}: {txt}")
|
| 306 |
+
if lines:
|
| 307 |
+
parts.append("Recent game exchanges:\n" + "\n".join(lines))
|
| 308 |
+
except Exception:
|
| 309 |
+
pass
|
| 310 |
return " ".join(parts)
|
| 311 |
except Exception:
|
| 312 |
return ""
|
|
|
|
| 479 |
history: list
|
| 480 |
game_context: Optional[Union[str, Dict[str, Any]]] = None
|
| 481 |
mode: Optional[str] = None
|
| 482 |
+
game_memory: Optional[Dict[str, Any]] = None
|
| 483 |
|
| 484 |
class TTSRequest(BaseModel):
|
| 485 |
text: str
|
|
|
|
| 505 |
history = query_request.history
|
| 506 |
game_context = query_request.game_context
|
| 507 |
mode = (query_request.mode or '').lower() or None
|
| 508 |
+
# Attach optional game_memory into context for prompt formatting
|
| 509 |
+
if query_request.game_memory is not None:
|
| 510 |
+
if isinstance(game_context, dict):
|
| 511 |
+
game_context = dict(game_context)
|
| 512 |
+
game_context['__memory'] = query_request.game_memory
|
| 513 |
+
else:
|
| 514 |
+
game_context = { 'context': game_context, '__memory': query_request.game_memory }
|
| 515 |
return StreamingResponse(stream_response(query, history, game_context, mode), media_type="text/event-stream")
|
| 516 |
|
| 517 |
@app.post("/api/tts")
|