GameAI / app.py
j-js's picture
Update app.py
3b94624 verified
from __future__ import annotations

import logging
import subprocess
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse

from context_parser import (
    detect_intent,
    extract_game_context_fields,
    intent_to_help_mode,
    split_unity_message,
)
from conversation_logic import ConversationEngine
from generator_engine import GeneratorEngine
from logging_store import LoggingStore
from models import ChatRequest, EventLogRequest, SessionFinalizeRequest, SessionStartRequest
from question_support_loader import question_support_bank
from retrieval_engine import RetrievalEngine
from ui_html import HOME_HTML
from utils import clamp01, get_user_text
# Shared service objects, created once when this module is imported and then
# referenced by every request handler below.
retriever = RetrievalEngine()
generator = GeneratorEngine()
engine = ConversationEngine(retriever=retriever, generator=generator)
store = LoggingStore()
# Load the question-support bank eagerly at import time, before any request is served.
question_support_bank.load()
app = FastAPI(title="GameAI")
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive CORS setup. Check how CORSMiddleware handles this combination and
# confirm it is intended for the deployed client.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Lightweight in-memory chat session state cache, keyed by session id
# (read and written by the /chat handler).
# NOTE(review): nothing in this file removes entries, so it grows with every
# distinct session id a client sends. Being a module-level dict, it is also
# per-process and is not shared between multiple worker processes.
CHAT_SESSION_STATE: Dict[str, Dict[str, Any]] = {}
def _as_dict(value: Any) -> Dict[str, Any]:
return value if isinstance(value, dict) else {}
def _extract_session_id(req_data: Dict[str, Any], req: ChatRequest) -> Optional[str]:
candidates = [
req_data.get("session_id"),
getattr(req, "session_id", None),
req_data.get("conversation_id"),
getattr(req, "conversation_id", None),
]
for candidate in candidates:
if isinstance(candidate, str) and candidate.strip():
return candidate.strip()
return None
def _extract_chat_history(req_data: Dict[str, Any], req: ChatRequest) -> List[Dict[str, Any]]:
candidates = [
req_data.get("chat_history"),
req_data.get("history"),
getattr(req, "chat_history", None),
getattr(req, "history", None),
]
for candidate in candidates:
if isinstance(candidate, list):
return [item for item in candidate if isinstance(item, dict)]
return []
def _recover_session_state_from_history(chat_history: List[Dict[str, Any]]) -> Dict[str, Any]:
for item in reversed(chat_history):
if not isinstance(item, dict):
continue
direct_state = item.get("session_state")
if isinstance(direct_state, dict) and direct_state:
return dict(direct_state)
meta = item.get("meta")
if isinstance(meta, dict):
meta_state = meta.get("session_state")
if isinstance(meta_state, dict) and meta_state:
return dict(meta_state)
return {}
def _merge_session_state(
cached_state: Dict[str, Any],
incoming_state: Dict[str, Any],
history_state: Dict[str, Any],
parsed_question_text: str,
parsed_hint_stage: int,
parsed_help_mode: str,
parsed_intent: str,
parsed_topic: str,
parsed_category: str,
parsed_user_last_input_type: str,
parsed_built_on_previous_turn: bool,
) -> Dict[str, Any]:
state: Dict[str, Any] = {}
if cached_state:
state.update(cached_state)
if history_state:
state.update(history_state)
if incoming_state:
state.update(incoming_state)
if parsed_question_text:
state["question_text"] = parsed_question_text
if parsed_hint_stage:
state["hint_stage"] = parsed_hint_stage
if parsed_user_last_input_type:
state["user_last_input_type"] = parsed_user_last_input_type
if parsed_built_on_previous_turn:
state["built_on_previous_turn"] = parsed_built_on_previous_turn
if parsed_help_mode:
state["help_mode"] = parsed_help_mode
if parsed_intent:
state["intent"] = parsed_intent
if parsed_topic:
state["topic"] = parsed_topic
if parsed_category:
state["category"] = parsed_category
return state
@app.get("/health")
def health() -> Dict[str, Any]:
    """Report basic service status for health checks.

    ``generator_available`` is taken from ``generator.available()``.
    NOTE(review): ``question_support_loaded`` is always reported as ``True``.
    It is not derived from ``question_support_bank`` here, so confirm this is
    intended.
    """
    status: Dict[str, Any] = {"ok": True, "app": "GameAI"}
    status["generator_available"] = generator.available()
    status["question_support_loaded"] = True
    return status
@app.get("/", response_class=HTMLResponse)
def home() -> str:
    """Serve the static landing page markup held in ``HOME_HTML``."""
    page = HOME_HTML
    return page
async def _read_request_body(request: Request) -> Any:
    """Return the request body as parsed JSON, else as decoded text, else ``None``."""
    try:
        return await request.json()
    except Exception:
        # Not JSON: fall back to the raw text so get_user_text can still
        # recover a plain-text message from it.
        try:
            return (await request.body()).decode("utf-8", errors="ignore")
        except Exception:
            return None


def _parsed_text(parsed: Dict[str, Any], key: str) -> str:
    """Return ``parsed[key]`` stripped, treating a missing or falsy value as ``""``."""
    return (parsed.get(key, "") or "").strip()


def _coerce_int(value: Any, default: int = 0) -> int:
    """Convert *value* to ``int``; return *default* for empty or unparseable values."""
    try:
        return int(value or default)
    except (TypeError, ValueError):
        return default


@app.post("/chat")
async def chat(request: Request) -> JSONResponse:
    """Handle one chat turn and return the assistant reply plus metadata.

    The body may be a JSON object, whose fields are mapped onto ``ChatRequest``,
    or any other payload, which is passed to ``get_user_text`` as-is. The steps
    are:

    1. Extract the user text and split it with ``split_unity_message`` into the
       visible message, hidden game context and parsed hints.
    2. Rebuild session state from the in-memory cache (by session id), the chat
       history, the request's ``session_state`` and the parsed hints.
    3. Resolve question, options, category, style sliders, intent and help
       mode. Explicit request fields win over parsed or derived values.
    4. Call ``engine.generate_response``, cache any returned
       ``meta["session_state"]`` under the session id, and respond with
       ``{"reply": ..., "meta": ...}``.

    Any exception is logged server-side and answered with HTTP 500 and
    ``{"error": <exception class name>, "detail": <message>}``.
    """
    try:
        raw_body: Any = await _read_request_body(request)
        # Only a JSON object maps onto ChatRequest; any other payload yields an
        # empty field set (get_user_text still receives the raw body).
        req_data: Dict[str, Any] = raw_body if isinstance(raw_body, dict) else {}
        req = ChatRequest(**req_data)

        full_text = get_user_text(req, raw_body)
        parsed = split_unity_message(full_text)
        hidden_context = parsed.get("hidden_context", "")
        actual_user_message = _parsed_text(parsed, "user_text")
        parsed_question_text = _parsed_text(parsed, "question_text")
        # A malformed hint stage should not fail the whole turn.
        parsed_hint_stage = _coerce_int(parsed.get("hint_stage", 0))
        parsed_help_mode = _parsed_text(parsed, "help_mode")
        parsed_intent = _parsed_text(parsed, "intent")
        parsed_topic = _parsed_text(parsed, "topic")
        parsed_category = _parsed_text(parsed, "category")
        parsed_user_last_input_type = _parsed_text(parsed, "user_last_input_type")
        parsed_built_on_previous_turn = bool(parsed.get("built_on_previous_turn", False))
        game_fields = extract_game_context_fields(hidden_context)

        chat_history = _extract_chat_history(req_data, req)
        incoming_session_state = _as_dict(req_data.get("session_state", getattr(req, "session_state", None)))
        history_session_state = _recover_session_state_from_history(chat_history)
        session_id = _extract_session_id(req_data, req)
        cached_session_state = CHAT_SESSION_STATE.get(session_id, {}) if session_id else {}
        session_state = _merge_session_state(
            cached_state=_as_dict(cached_session_state),
            incoming_state=incoming_session_state,
            history_state=history_session_state,
            parsed_question_text=parsed_question_text,
            parsed_hint_stage=parsed_hint_stage,
            parsed_help_mode=parsed_help_mode,
            parsed_intent=parsed_intent,
            parsed_topic=parsed_topic,
            parsed_category=parsed_category,
            parsed_user_last_input_type=parsed_user_last_input_type,
            parsed_built_on_previous_turn=parsed_built_on_previous_turn,
        )

        # Each value falls back from explicit request field -> parsed message
        # -> hidden game context -> remembered session state.
        question_text = (
            (getattr(req, "question_text", None) or "").strip()
            or parsed_question_text
            or game_fields.get("question", "")
            or str(session_state.get("question_text", "") or "").strip()
        )
        options_text = getattr(req, "options_text", None) or game_fields.get("options", [])
        question_id = req_data.get("question_id") or getattr(req, "question_id", None) or session_state.get("question_id")
        category = (
            req_data.get("category")
            or getattr(req, "category", None)
            or parsed_category
            or game_fields.get("category")
            or session_state.get("category")
        )
        tone = clamp01(req_data.get("tone", getattr(req, "tone", 0.5)), 0.5)
        verbosity = clamp01(req_data.get("verbosity", getattr(req, "verbosity", 0.5)), 0.5)
        transparency = clamp01(req_data.get("transparency", getattr(req, "transparency", 0.5)), 0.5)
        incoming_help_mode = req_data.get("help_mode") or getattr(req, "help_mode", None) or parsed_help_mode or None
        explicit_intent = req_data.get("intent") or getattr(req, "intent", None) or parsed_intent or None
        resolved_user_text = req_data.get("raw_user_text") or actual_user_message or full_text or ""
        resolved_user_text = str(resolved_user_text).strip()
        intent = explicit_intent or detect_intent(resolved_user_text, incoming_help_mode)
        help_mode = incoming_help_mode or intent_to_help_mode(intent)

        # NOTE(review): this synchronous call runs directly inside an async
        # endpoint, so the event loop is blocked for its whole duration. If
        # generation is slow and the engine is thread-safe, consider
        # fastapi.concurrency.run_in_threadpool.
        result = engine.generate_response(
            raw_user_text=resolved_user_text,
            tone=tone,
            verbosity=verbosity,
            transparency=transparency,
            intent=intent,
            help_mode=help_mode,
            chat_history=chat_history,
            question_text=question_text,
            options_text=options_text,
            question_id=question_id,
            session_state=session_state,
            category=category,
        )

        meta: Dict[str, Any] = {
            "domain": result.domain,
            "solved": result.solved,
            "help_mode": result.help_mode,
            "answer_letter": result.answer_letter,
            "answer_value": result.answer_value,
            "topic": result.topic,
            "used_retrieval": result.used_retrieval,
            "used_generator": result.used_generator,
        }
        # Engine-provided metadata may override the base fields above.
        if isinstance(result.meta, dict):
            meta.update(result.meta)

        returned_session_state = _as_dict(meta.get("session_state"))
        if session_id and returned_session_state:
            CHAT_SESSION_STATE[session_id] = dict(returned_session_state)
        return JSONResponse({"reply": result.reply, "meta": meta})
    except Exception as e:
        # Top-level boundary: keep the traceback in the server logs (it was
        # previously discarded), then report the failure to the client.
        logging.getLogger(__name__).exception("Unhandled error while handling /chat")
        return JSONResponse({"error": type(e).__name__, "detail": str(e)}, status_code=500)
@app.post("/log/session/start")
def log_session_start(payload: SessionStartRequest) -> Dict[str, Any]:
return store.start_session(payload.session_id, payload.user_id, payload.condition, payload.metadata)
@app.post("/log/event")
def log_event(payload: EventLogRequest) -> Dict[str, Any]:
return store.log_event(payload.session_id, payload.event_type, payload.payload, payload.timestamp)
@app.post("/log/session/finalize")
def log_session_finalize(payload: SessionFinalizeRequest) -> Dict[str, Any]:
return store.finalize_session(payload.session_id, payload.summary)
@app.get("/research/sessions")
def research_sessions() -> Dict[str, Any]:
return {"sessions": store.list_sessions()}
@app.get("/research/session/{session_id}")
def research_session(session_id: str) -> Dict[str, Any]:
return store.get_session(session_id)