Spaces:
Sleeping
Sleeping
"""
WebSocket handler for real-time MindSphere Coach updates.
"""
from __future__ import annotations

import json
import math
from typing import Any, Dict

from fastapi import WebSocket, WebSocketDisconnect

from .session import SessionManager
async def websocket_endpoint(
    websocket: WebSocket,
    session_id: str,
    session_manager: SessionManager,
):
    """
    WebSocket handler for a coaching session.

    Accepts the connection, resolves the agent for ``session_id`` and then
    loops over incoming JSON messages until the client disconnects. Invalid
    input (malformed JSON, non-object payloads, unknown message types, bad
    empathy values) is answered with an ``error`` frame and the connection
    stays open.

    Args:
        websocket: The connection to serve; it is accepted here.
        session_id: Identifier of the coaching session to attach to.
        session_manager: Provides ``session_exists``, ``get_agent`` and
            ``add_to_history`` for the session.

    Protocol:
    Client -> Server:
        {"type": "user_message", "content": "...", "message_type": "text|mc|choice"}
        {"type": "user_message", "content": "...", "answer_index": 0}
        {"type": "user_message", "content": "...", "choice": "accept"}
        {"type": "continue"} (for phase transitions)
        {"type": "set_empathy", "value": 0.7}
    Server -> Client:
        {"type": "assistant_message", "content": "...", "phase": "..."}
        {"type": "question", "data": {...}}
        {"type": "sphere_update", "data": {...}}
        {"type": "counterfactual", "data": {...}}
        {"type": "intervention", "data": {...}}
        {"type": "phase_change", "phase": "...", "progress": ..., "is_complete": false}
        {"type": "belief_update", "data": {...}}
        {"type": "empathy_updated", "value": 0.7}
        {"type": "error", "message": "..."}

    Note: ``message_type`` is accepted from the client but not used here.
    """
    await websocket.accept()

    if not session_manager.session_exists(session_id):
        await _send_error(websocket, f"Session {session_id} not found")
        await websocket.close()
        return

    agent = session_manager.get_agent(session_id)
    if agent is None:
        await _send_error(websocket, "Agent not found")
        await websocket.close()
        return

    try:
        while True:
            try:
                data = await websocket.receive_json()
            except json.JSONDecodeError:
                # The frame was consumed but is not valid JSON; report it and
                # keep serving instead of tearing down the connection.
                await _send_error(websocket, "Invalid JSON payload")
                continue

            if not isinstance(data, dict):
                # Valid JSON that is not an object (list, string, number, ...)
                # has no "type" field to dispatch on.
                await _send_error(websocket, "Message must be a JSON object")
                continue

            msg_type = data.get("type", "")
            if msg_type == "user_message":
                await _handle_user_message(
                    websocket, session_id, session_manager, agent, data
                )
            elif msg_type == "continue":
                await _handle_continue(websocket, agent)
            elif msg_type == "set_empathy":
                await _handle_set_empathy(websocket, agent, data)
            else:
                await _send_error(websocket, f"Unknown message type: {msg_type!r}")
    except WebSocketDisconnect:
        # Client went away; this handler holds nothing that needs cleanup.
        pass


async def _send_error(websocket: WebSocket, message: str) -> None:
    """Send an ``error`` frame carrying ``message`` to the client."""
    await websocket.send_json({"type": "error", "message": message})


async def _send_data_if_present(
    websocket: WebSocket,
    result: Dict[str, Any],
    key: str,
    frame_type: str,
) -> None:
    """Forward ``result[key]`` as a ``{"type", "data"}`` frame when it is truthy."""
    payload = result.get(key)
    if payload:
        await websocket.send_json({"type": frame_type, "data": payload})


async def _send_assistant_message(websocket: WebSocket, result: Dict[str, Any]) -> None:
    """Send the ``assistant_message`` frame for ``result`` (caller checks presence)."""
    await websocket.send_json({
        "type": "assistant_message",
        "content": result["message"],
        "phase": result.get("phase", ""),
    })


async def _send_phase_change(websocket: WebSocket, result: Dict[str, Any]) -> None:
    """Send a ``phase_change`` frame with phase, progress and completion flag."""
    await websocket.send_json({
        "type": "phase_change",
        "phase": result.get("phase", ""),
        "progress": result.get("progress"),
        "is_complete": result.get("is_complete", False),
    })


async def _handle_user_message(
    websocket: WebSocket,
    session_id: str,
    session_manager: SessionManager,
    agent: Any,
    data: Dict[str, Any],
) -> None:
    """Run one agent step for a ``user_message`` and stream the resulting frames."""
    content = data.get("content", "")

    # Build the agent input; the optional fields are forwarded only when the
    # client actually supplied them.
    user_input: Dict[str, Any] = {"answer": content}
    for optional_key in ("answer_index", "choice"):
        if optional_key in data:
            user_input[optional_key] = data[optional_key]

    session_manager.add_to_history(session_id, "user", content)
    result = agent.step(user_input)

    if result.get("message"):
        session_manager.add_to_history(session_id, "assistant", result["message"])
        await _send_assistant_message(websocket, result)

    await _send_data_if_present(websocket, result, "question", "question")
    await _send_data_if_present(websocket, result, "sphere_data", "sphere_update")
    await _send_data_if_present(websocket, result, "counterfactual", "counterfactual")
    await _send_data_if_present(websocket, result, "intervention", "intervention")

    # Always sent, even when the phase did not change.
    await _send_phase_change(websocket, result)

    if result.get("tom_stats") or result.get("user_type_summary"):
        await websocket.send_json({
            "type": "belief_update",
            "data": {
                "tom_stats": result.get("tom_stats"),
                "user_type": result.get("user_type_summary"),
                "beliefs": agent.get_belief_summary(),
            },
        })


async def _handle_continue(websocket: WebSocket, agent: Any) -> None:
    """Trigger a phase transition with an empty agent step and stream the result."""
    result = agent.step({})

    # phase_change goes first in this branch; it now carries the same
    # progress / is_complete fields as the user_message branch.
    await _send_phase_change(websocket, result)

    # NOTE(review): unlike the user_message branch, the assistant message is
    # not recorded in the session history here -- confirm whether intended.
    if result.get("message"):
        await _send_assistant_message(websocket, result)

    await _send_data_if_present(websocket, result, "sphere_data", "sphere_update")
    await _send_data_if_present(websocket, result, "intervention", "intervention")
    await _send_data_if_present(websocket, result, "counterfactual", "counterfactual")


async def _handle_set_empathy(
    websocket: WebSocket,
    agent: Any,
    data: Dict[str, Any],
) -> None:
    """Validate the requested empathy value, apply it and acknowledge it."""
    try:
        value = float(data.get("value", 0.5))
    except (TypeError, ValueError, OverflowError):
        # null / list / non-numeric string / integer too large for a float.
        await _send_error(websocket, "Empathy value must be a number")
        return

    if not math.isfinite(value):
        # float() accepts "nan" / "inf"; echoing those back would serialize
        # as NaN / Infinity tokens that strict JSON parsers reject.
        await _send_error(websocket, "Empathy value must be a finite number")
        return

    agent.set_empathy_dial(value)
    await websocket.send_json({
        "type": "empathy_updated",
        "value": value,
    })