# TradingGameAI / app.py
# Last commit: 267d20c "Update app.py" (author: j-js, verified)
from __future__ import annotations
import os
import re
from typing import Any, Dict, Tuple
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from context_parser import detect_help_mode
from conversation_logic import generate_response
from models import ChatRequest
from ui_html import HOME_HTML
from utils import clamp01, get_user_text
from retrieval_engine import RetrievalEngine
# Retrieval engine built once when the module is imported; every /chat request
# below reuses this single instance.
retriever = RetrievalEngine()

app = FastAPI(title="GMAT Solver v3", version="3.1.0")

# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# confirm the client actually needs credentialed cross-origin requests, and
# otherwise consider restricting origins or disabling credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
def split_unity_message(full_text: str) -> Tuple[str, str]:
    """
    Separate a Unity chat payload into ``(hidden_context, actual_user_message)``.

    The expected Unity format is::

        <hidden context>
        USER_MESSAGE:
        <user text>

    Everything before the first ``USER_MESSAGE:`` marker is the hidden
    context, and everything after it is the player's message. Both halves are
    whitespace-stripped.

    If there is no marker, the whole stripped payload is treated as the user
    message and the context is "". Empty (or otherwise falsy) input gives
    ``("", "")``.
    """
    if not full_text:
        return "", ""

    head, found_marker, tail = full_text.partition("USER_MESSAGE:")
    if not found_marker:
        # No marker: fall back to treating the entire payload as user text.
        return "", full_text.strip()

    return head.strip(), tail.strip()
# Single-line header fields. The value is the rest of the label's own line,
# which may be empty. ``[ \t]*`` is used instead of ``\s*`` so that an empty
# field cannot swallow the following line.
_CATEGORY_RE = re.compile(r"Category:[ \t]*(.*)")
_DIFFICULTY_RE = re.compile(r"Difficulty:[ \t]*(.*)")
# Multi-line sections. Capture lazily up to the next known section label at
# the start of a line, or up to the end of the context.
_QUESTION_RE = re.compile(
    r"Question:[ \t]*(.*?)(?=\n(?:Options|Player balance|Last outcome):|\Z)",
    re.DOTALL,
)
_OPTIONS_RE = re.compile(
    r"Options:[ \t]*(.*?)(?=\n(?:Player balance|Last outcome):|\Z)",
    re.DOTALL,
)


def extract_game_context_fields(hidden_context: str) -> Dict[str, str]:
    """
    Pull simple fields out of the hidden Unity context.

    This lets downstream logic use the real question and options instead of
    the whole raw blob. Each label is matched at its first occurrence:

    * ``Category:`` / ``Difficulty:`` -- single-line values. The rest of the
      label's line is taken, so an empty field yields "" rather than the text
      of the following line.
    * ``Question:`` -- may span several lines. It ends at the next line that
      starts with ``Options:``, ``Player balance:`` or ``Last outcome:``.
    * ``Options:`` -- may span several lines. It ends at the next line that
      starts with ``Player balance:`` or ``Last outcome:``.

    Returns a dict with the keys "category", "difficulty", "question" and
    "options". Each value is whitespace-stripped, and is "" when its label is
    absent or empty, or when ``hidden_context`` itself is empty.
    """
    fields = {
        "category": "",
        "difficulty": "",
        "question": "",
        "options": "",
    }
    if not hidden_context:
        return fields

    for key, pattern in (
        ("category", _CATEGORY_RE),
        ("difficulty", _DIFFICULTY_RE),
        ("question", _QUESTION_RE),
        ("options", _OPTIONS_RE),
    ):
        match = pattern.search(hidden_context)
        if match:
            fields[key] = match.group(1).strip()
    return fields
@app.get("/health")
def health() -> Dict[str, Any]:
    """Liveness check: report ``ok`` together with a fixed identifying label."""
    status: Dict[str, Any] = {"ok": True}
    status["app"] = "GMAT Solver v3 LIVE CHECK 777"
    return status
@app.get("/", response_class=HTMLResponse)
def home() -> str:
    """Serve the HTML page from ``ui_html.HOME_HTML`` at the site root."""
    page = HOME_HTML
    return page
async def _read_raw_body(request: Request) -> Any:
    """
    Best-effort read of the request payload.

    Returns the parsed JSON value if the body is valid JSON. Otherwise returns
    the body decoded as UTF-8 text, dropping undecodable bytes. Returns
    ``None`` if the raw body cannot be read either. ``Exception``s from both
    reads are swallowed.
    """
    try:
        return await request.json()
    except Exception:
        pass  # Not JSON (or unreadable as JSON); fall back to the raw text.
    try:
        return (await request.body()).decode("utf-8", errors="ignore")
    except Exception:
        return None


@app.post("/chat")
async def chat(request: Request) -> JSONResponse:
    """
    Answer one chat turn.

    Steps:

    1. Read the payload (a JSON object or raw text).
    2. Split the user text into hidden Unity context and the player's
       message.
    3. Extract the question fields from the context.
    4. Retrieve supporting text for the message.
    5. Delegate to ``generate_response``.

    The reply is returned as ``{"reply": ..., "meta": {...}}``.
    """
    raw_body = await _read_raw_body(request)

    # Only a JSON object can carry structured settings; raw text, lists or
    # None contribute no fields.
    req_data: Dict[str, Any] = raw_body if isinstance(raw_body, dict) else {}
    try:
        req = ChatRequest(**req_data)
    except Exception:
        # Invalid fields must not fail the request; use the model's defaults.
        req = ChatRequest()

    hidden_context, actual_user_message = split_unity_message(
        get_user_text(req, raw_body)
    )
    game_fields = extract_game_context_fields(hidden_context)

    # Retrieval is keyed on the player's message only, not the hidden context.
    # NOTE(review): retriever.search and generate_response are called without
    # await inside this async handler, so the event loop is blocked while they
    # run. Confirm they are fast, or offload them after checking their
    # thread-safety.
    retrieval_context = "\n".join(retriever.search(actual_user_message))

    # A key present in the raw JSON wins over the model attribute. That
    # matters when ChatRequest() defaults were used above. The 0.5 passed to
    # clamp01 is presumably its fallback value -- confirm in utils.
    sliders = {
        knob: clamp01(req_data.get(knob, getattr(req, knob)), 0.5)
        for knob in ("tone", "verbosity", "transparency")
    }

    help_mode = detect_help_mode(
        actual_user_message,
        req_data.get("help_mode", req.help_mode),
    )

    result = generate_response(
        raw_user_text=actual_user_message,
        tone=sliders["tone"],
        verbosity=sliders["verbosity"],
        transparency=sliders["transparency"],
        help_mode=help_mode,
        hidden_context=hidden_context,
        chat_history=req_data.get("chat_history", []),
        question_text=game_fields["question"],
        options_text=game_fields["options"],
        question_category=game_fields["category"],
        question_difficulty=game_fields["difficulty"],
        retrieval_context=retrieval_context,
    )

    reply = result.reply
    meta = {
        "domain": result.domain,
        "solved": result.solved,
        "help_mode": result.help_mode,
        "answer_letter": result.answer_letter,
        "answer_value": result.answer_value,
    }
    return JSONResponse({"reply": reply, "meta": meta})
if __name__ == "__main__":
    # Direct-run entry point. The listening port comes from $PORT and
    # defaults to 7860.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))