Spaces:
Sleeping
Sleeping
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| import os | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| from server.env import MindReadEnv | |
| from server.models import ( | |
| MindReadObservation, | |
| StepResult, | |
| SubmitResult, | |
| TaskMeta, | |
| Secret, | |
| GenerateSecretRequest, | |
| HealthResponse, | |
| AskQuestionAction, | |
| ) | |
| from server.secret_generator import generate_secret | |
# Single module-level MindReadEnv instance shared by every request handler in
# this module (they all call methods on it with an episode_id).
# NOTE(review): FastAPI runs plain `def` handlers in a threadpool, so this
# object may be called concurrently — confirm MindReadEnv is thread-safe.
env = MindReadEnv()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook passed to ``FastAPI(lifespan=...)``.

    Warms up the embedding model at startup by calling ``get_embedder()``
    once before the server begins accepting requests. Nothing is torn down
    after ``yield``.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    # Imported inside the hook, presumably so the reward/embedding module is
    # only loaded when the app actually starts.
    from server.reward import get_embedder

    # Called only for its side effect; the return value is discarded.
    # NOTE(review): assumes get_embedder() caches the model so later calls
    # reuse it — confirm in server.reward.
    get_embedder()
    yield
# The ASGI application. Its lifespan hook (defined above) warms up the
# embedding model before the first request is served.
# NOTE(review): no @app.<method>(...) route decorators are visible on the
# handler functions in this copy of the file — confirm they are registered.
app = FastAPI(
    lifespan=lifespan,
    title="MindRead: Theory of Mind RL Environment",
    version="1.0.0",
    description=(
        "Interactive Theory of Mind training environment. "
        "An LLM agent (Detective) must infer a hidden mental state by asking "
        "strategic questions to an Oracle. "
        "Trains functional theory of mind — the ability to adapt questioning "
        "strategy based on Oracle responses."
    ),
)
def health():
    """Report liveness plus basic service metadata.

    The version is read from ``app.version`` so it cannot drift from the
    version the application advertises in its OpenAPI schema.
    """
    return HealthResponse(
        status="ok",
        version=app.version,
        # NOTE(review): hard-coded label; confirm it matches the Oracle backend
        # actually configured in the server package.
        oracle_backend="groq/llama-3.1-8b-instant",
    )
def get_tasks():
    """List the tasks exposed by the shared environment.

    Delegates directly to ``env.get_tasks()``; the response is whatever that
    method returns.
    """
    available = env.get_tasks()
    return available
class ResetRequest(BaseModel):
    """Request body for starting a new episode on a task."""

    # Task to start; env.reset raises ValueError for bad input (mapped to 400).
    task_id: str
    # Optional specific secret, forwarded to env.reset unchanged.
    # NOTE(review): None presumably lets the env pick a secret — confirm in server.env.
    secret_id: Optional[str] = None
def reset(request: ResetRequest):
    """Start a new episode for ``request.task_id``.

    ``request.secret_id`` is passed through to ``env.reset``. The result of
    ``env.reset`` is returned unchanged.

    Raises:
        HTTPException: 400 if ``env.reset`` raises ValueError, 500 if it
            raises RuntimeError. The original exception is chained as the cause.
    """
    try:
        return env.reset(task_id=request.task_id, secret_id=request.secret_id)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except RuntimeError as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
class StepRequest(BaseModel):
    """Request body for /step: one action within an existing episode."""

    # Episode identifier, as exposed on the observation returned by env.reset.
    episode_id: str
    # Expected to have action == "ask_question" and a non-blank question;
    # both are validated by the step handler, not by this model.
    action: AskQuestionAction
def step(request: StepRequest):
    """Ask the Oracle one question within an existing episode.

    Only ``ask_question`` actions are accepted; hypotheses go to /submit.
    The question is stripped of surrounding whitespace before being passed to
    ``env.step``, whose result is returned unchanged.

    Raises:
        HTTPException: 400 for a non-question action, a missing/blank question,
            or a ValueError from ``env.step``; 404 if ``env.step`` raises
            KeyError. The original exception is chained as the cause.
    """
    action = request.action
    if action.action != "ask_question":
        raise HTTPException(
            status_code=400,
            detail="Use /submit to submit a hypothesis. /step only accepts ask_question.",
        )
    # None and whitespace-only questions both collapse to "" here.
    question = (action.question or "").strip()
    if not question:
        raise HTTPException(status_code=400, detail="Question must not be empty.")
    try:
        return env.step(request.episode_id, question)
    except KeyError as e:
        # str(KeyError("msg")) is "'msg'" (repr-quoted); unwrap a single arg.
        detail = str(e.args[0]) if len(e.args) == 1 else str(e)
        raise HTTPException(status_code=404, detail=detail) from e
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
class SubmitRequest(BaseModel):
    """Request body for /submit: the Detective's hypothesis for an episode."""

    # Episode identifier, as exposed on the observation returned by env.reset.
    episode_id: str
    # Free-text hypothesis; must be non-blank (validated by the submit handler).
    hypothesis: str
    # Optional category guess, forwarded to env.submit as-is.
    category_prediction: Optional[str] = None
def submit(request: SubmitRequest):
    """Submit a hypothesis for an episode and return the env's result.

    The hypothesis is stripped of surrounding whitespace before being passed
    to ``env.submit``. ``category_prediction`` is forwarded unchanged.

    Raises:
        HTTPException: 400 for a missing/blank hypothesis or a ValueError from
            ``env.submit``; 404 if ``env.submit`` raises KeyError. The original
            exception is chained as the cause.
    """
    hypothesis = (request.hypothesis or "").strip()
    if not hypothesis:
        raise HTTPException(status_code=400, detail="Hypothesis must not be empty.")
    try:
        return env.submit(
            episode_id=request.episode_id,
            hypothesis=hypothesis,
            category_prediction=request.category_prediction,
        )
    except KeyError as e:
        # str(KeyError("msg")) is "'msg'" (repr-quoted); unwrap a single arg.
        detail = str(e.args[0]) if len(e.args) == 1 else str(e)
        raise HTTPException(status_code=404, detail=detail) from e
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
def get_state(episode_id: str):
    """Return the environment's state for ``episode_id``.

    Raises:
        HTTPException: 404 if ``env.get_state`` raises KeyError. The original
            exception is chained as the cause.
    """
    try:
        return env.get_state(episode_id)
    except KeyError as e:
        # str(KeyError("msg")) is "'msg'" (repr-quoted); unwrap a single arg.
        detail = str(e.args[0]) if len(e.args) == 1 else str(e)
        raise HTTPException(status_code=404, detail=detail) from e
def generate_secret_endpoint(request: GenerateSecretRequest):
    """Generate a new secret, register it, and open an episode on it.

    Calls ``generate_secret`` with the request's category/difficulty/domain and
    validates the result as a ``Secret``. It then adds the secret to the shared
    env and resets an episode on the secret's own task.

    Returns:
        dict with ``"secret"`` (the raw generator output) and ``"episode_id"``
        (taken from the observation returned by ``env.reset``).

    Raises:
        HTTPException: 500 for any failure in the steps above. The underlying
            exception is logged with its traceback and chained as the cause.
    """
    import logging

    try:
        secret_data = generate_secret(
            category=request.category,
            difficulty=request.difficulty,
            domain=request.domain,
        )
        secret = Secret(**secret_data)
        env.add_secret(secret)
        # NOTE(review): if env.reset fails here, the secret added above stays
        # registered — there is no rollback.
        obs = env.reset(task_id=secret.task_id, secret_id=secret.id)
        return {"secret": secret_data, "episode_id": obs.episode_id}
    except Exception as e:
        # Catch-all boundary: keep the traceback in the logs, since the client
        # only receives str(e).
        logging.getLogger(__name__).exception("generate_secret_endpoint failed")
        raise HTTPException(status_code=500, detail=str(e)) from e