import asyncio
import base64
import json
import logging
from contextlib import suppress
from typing import Any, Dict, Optional

import gradio as gr
from fastapi import FastAPI, File, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

from brain import BrainManager
from middleware import MCPMiddleware
from unity_bridge import UnityBridge

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)

app = FastAPI(title="Embodied AI Teacher Platform", version="1.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Singletons shared by every endpoint: the LLM "brain", the MCP event
# middleware, and the Unity motion bridge (which contributes its own router).
brain = BrainManager()
middleware = MCPMiddleware()
unity = UnityBridge()
app.include_router(unity.router)


class TeachRequest(BaseModel):
    """Inbound payload for the /teach endpoint."""

    text: str = Field(..., description="Student utterance or question")
    image_url: Optional[str] = Field(None, description="Optional multimodal image URL")


async def _publish_speech_chunks(speech: str) -> None:
    """Publish *speech* token by token on the speech-chunk topic.

    The short sleep paces the stream so WebSocket subscribers receive
    chunks incrementally instead of in a single burst.
    """
    for token in speech.split():
        await middleware.publish("teacher.speech.chunk", {"token": token})
        await asyncio.sleep(0.01)


@app.get("/health")
async def health() -> Dict[str, str]:
    """Liveness probe."""
    return {"status": "ok"}


@app.post("/teach")
async def teach(req: TeachRequest) -> Dict[str, Any]:
    """Generate a teacher action for a student utterance and fan it out.

    Pipeline: brain generates a raw action, the MCP middleware processes it,
    the processed action is broadcast to Unity and its speech is streamed
    out chunk by chunk. Returns the processed action plus a telemetry count.
    """
    action_raw = await brain.generate_teacher_action(req.text, image_url=req.image_url)
    action = await middleware.apply_teacher_action(action_raw)
    # FIX: broadcast the middleware-processed action (dict form) to Unity,
    # matching the WebSocket handler below — previously this endpoint sent
    # the unprocessed raw action, so REST and WS clients drove Unity
    # inconsistently.
    await unity.broadcast_motion(action.__dict__)
    await _publish_speech_chunks(action.speech)
    return {"action": action.__dict__, "telemetry_count": len(middleware.telemetry)}


@app.post("/speech/stream")
async def speech_stream(text: str) -> StreamingResponse:
    """Stream *text* back as whitespace-delimited UTF-8 chunks.

    Placeholder for a TTS pipeline: today it echoes tokens with a small
    delay between them.
    """

    async def chunk_stream():
        for token in text.split():
            yield f"{token} ".encode("utf-8")
            await asyncio.sleep(0.03)

    # FIX: the stream is UTF-8 text, not WAV audio; the previous
    # media_type="audio/wav" made clients mis-decode the payload. When a
    # real TTS model is integrated, switch back to an audio media type.
    return StreamingResponse(chunk_stream(), media_type="text/plain; charset=utf-8")


@app.post("/speech/upload")
async def speech_upload(file: UploadFile = File(...)) -> Dict[str, Any]:
    """Accept an uploaded audio file and return basic metadata.

    ASR integration placeholder: only size and a base64 preview are
    returned, no transcription is performed yet.
    """
    raw = await file.read()
    content_b64 = base64.b64encode(raw).decode("utf-8")
    return {
        "filename": file.filename,
        "bytes": len(raw),
        "preview": content_b64[:160],
        "note": "Integrate ASR model here for transcription.",
    }


@app.websocket("/ws")
async def classroom_ws(websocket: WebSocket) -> None:
    """Bidirectional classroom channel.

    Outbound: one pump task per middleware topic relays events to the
    client. Inbound: 'student_input' runs the full teach pipeline;
    'telemetry_request' returns a telemetry snapshot. All pump tasks are
    cancelled and awaited on disconnect.
    """
    await websocket.accept()
    tasks: list[asyncio.Task] = []

    async def pump(topic: str, event_type: str) -> None:
        # Forward every middleware event on *topic* to this client,
        # flattening the payload into the outer JSON object.
        async for event in middleware.subscribe(topic):
            await websocket.send_text(
                json.dumps(
                    {
                        "type": event_type,
                        "topic": event.topic,
                        "ts": event.ts,
                        **event.payload,
                    }
                )
            )

    topics = {
        "teacher.actions": "teacher_action",
        "teacher.board.write": "board_write",
        "teacher.board.draw": "board_draw",
        "teacher.speech.chunk": "speech_chunk",
    }
    try:
        for topic, event_type in topics.items():
            tasks.append(asyncio.create_task(pump(topic, event_type)))
        while True:
            inbound = await websocket.receive_text()
            msg = json.loads(inbound)
            if msg.get("type") == "student_input":
                action_raw = await brain.generate_teacher_action(
                    msg.get("text", ""), image_url=msg.get("image_url")
                )
                action = await middleware.apply_teacher_action(action_raw)
                await unity.broadcast_motion(action.__dict__)
                await _publish_speech_chunks(action.speech)
                await websocket.send_text(
                    json.dumps({"type": "ack", "state": action.teaching_state})
                )
            elif msg.get("type") == "telemetry_request":
                await websocket.send_text(
                    json.dumps(
                        {
                            "type": "telemetry_snapshot",
                            "events": middleware.get_telemetry_snapshot(),
                        }
                    )
                )
    except WebSocketDisconnect:
        LOGGER.info("Classroom client disconnected")
    finally:
        # Tear down the pump tasks; suppress the CancelledError each one
        # raises so cleanup of later tasks is not skipped.
        for task in tasks:
            task.cancel()
            with suppress(asyncio.CancelledError):
                await task


def _gradio_teach(text: str) -> str:
    """Synchronous Gradio bridge into the async brain.

    Runs in Gradio's worker thread, so spinning up a fresh event loop is
    safe. Uses asyncio.run() instead of the previous manual
    new_event_loop/run_until_complete/close sequence — same semantics,
    with loop shutdown (async generators, executor) handled correctly.
    """
    # NOTE(review): assumes the raw action is JSON-serializable (the
    # original code made the same assumption) — verify against BrainManager.
    action = asyncio.run(brain.generate_teacher_action(text))
    return json.dumps(action, indent=2)


gradio_ui = gr.Interface(
    fn=_gradio_teach,
    inputs=gr.Textbox(label="Student Question"),
    outputs=gr.Code(language="json", label="MCP Teacher Action"),
    title="Embodied Teacher Brain Console",
    description="Fast introspection surface for Hugging Face Spaces.",
)
app = gr.mount_gradio_app(app, gradio_ui, path="/gradio")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)