# mindsphere_coach/src/mindsphere/api/websocket.py
# Mahault
# Initial commit: MindSphere Coach — ToM-powered coaching agent
# 157b149
"""
WebSocket handler for real-time MindSphere Coach updates.
"""
from __future__ import annotations
import json
from typing import Any, Dict
from fastapi import WebSocket, WebSocketDisconnect
from .session import SessionManager
async def _send_error(websocket: WebSocket, message: str) -> None:
    """Send a protocol-level ``error`` message to the client."""
    await websocket.send_json({"type": "error", "message": message})


async def _send_data_if_present(
    websocket: WebSocket,
    result: Dict[str, Any],
    result_key: str,
    message_type: str,
) -> None:
    """Forward ``result[result_key]`` as ``{"type", "data"}`` when it is truthy."""
    payload = result.get(result_key)
    if payload:
        await websocket.send_json({"type": message_type, "data": payload})


async def websocket_endpoint(
    websocket: WebSocket,
    session_id: str,
    session_manager: SessionManager,
):
    """
    WebSocket handler for a coaching session.

    Accepts the connection, resolves the agent for ``session_id`` and then
    loops over incoming JSON messages until the client disconnects. If the
    session or its agent cannot be found, an ``error`` message is sent and
    the socket is closed.

    Malformed input (invalid JSON, a payload that is not a JSON object, or a
    non-numeric empathy value) is answered with an ``error`` message and the
    connection stays open. Messages with an unrecognised ``type`` are ignored.

    Protocol:
    Client -> Server:
    {"type": "user_message", "content": "...", "message_type": "text|mc|choice"}
    {"type": "user_message", "content": "...", "answer_index": 0}
    {"type": "user_message", "content": "...", "choice": "accept"}
    {"type": "continue"} (for phase transitions)
    {"type": "set_empathy", "value": 0.7}
    ("message_type" is accepted but not used by this handler.)
    Server -> Client:
    {"type": "assistant_message", "content": "...", "phase": "..."}
    {"type": "question", "data": {...}}
    {"type": "sphere_update", "data": {...}}
    {"type": "counterfactual", "data": {...}}
    {"type": "intervention", "data": {...}}
    {"type": "phase_change", "phase": "...", "progress": ..., "is_complete": false}
    (after "continue", "phase_change" carries only "phase")
    {"type": "belief_update", "data": {...}}
    {"type": "empathy_updated", "value": 0.7}
    {"type": "error", "message": "..."}
    """
    await websocket.accept()

    if not session_manager.session_exists(session_id):
        await _send_error(websocket, f"Session {session_id} not found")
        await websocket.close()
        return

    agent = session_manager.get_agent(session_id)
    if agent is None:
        await _send_error(websocket, "Agent not found")
        await websocket.close()
        return

    try:
        while True:
            # A malformed frame must not tear down the whole session: report
            # it to the client and keep listening.
            try:
                data = await websocket.receive_json()
            except json.JSONDecodeError:
                await _send_error(websocket, "Invalid JSON payload")
                continue

            # Valid JSON can still be a list/string/number, which has no .get().
            if not isinstance(data, dict):
                await _send_error(websocket, "Message must be a JSON object")
                continue

            msg_type = data.get("type", "")

            if msg_type == "user_message":
                content = data.get("content", "")

                # Build user input; optional fields are forwarded only when
                # the client supplied them.
                user_input: Dict[str, Any] = {"answer": content}
                for optional_key in ("answer_index", "choice"):
                    if optional_key in data:
                        user_input[optional_key] = data[optional_key]

                # Store in history before stepping the agent.
                session_manager.add_to_history(session_id, "user", content)

                # Process step
                result = agent.step(user_input)

                # Send assistant message
                message = result.get("message")
                if message:
                    session_manager.add_to_history(session_id, "assistant", message)
                    await websocket.send_json({
                        "type": "assistant_message",
                        "content": message,
                        "phase": result.get("phase", ""),
                    })

                # Optional payloads, in the order clients already receive them.
                await _send_data_if_present(websocket, result, "question", "question")
                await _send_data_if_present(websocket, result, "sphere_data", "sphere_update")
                await _send_data_if_present(websocket, result, "counterfactual", "counterfactual")
                await _send_data_if_present(websocket, result, "intervention", "intervention")

                # Phase change is always sent after a user message.
                await websocket.send_json({
                    "type": "phase_change",
                    "phase": result.get("phase", ""),
                    "progress": result.get("progress"),
                    "is_complete": result.get("is_complete", False),
                })

                # Send belief update
                if result.get("tom_stats") or result.get("user_type_summary"):
                    await websocket.send_json({
                        "type": "belief_update",
                        "data": {
                            "tom_stats": result.get("tom_stats"),
                            "user_type": result.get("user_type_summary"),
                            "beliefs": agent.get_belief_summary(),
                        },
                    })

            elif msg_type == "continue":
                # Trigger phase transition with an empty input.
                result = agent.step({})
                await websocket.send_json({
                    "type": "phase_change",
                    "phase": result.get("phase", ""),
                })

                # NOTE(review): unlike the user_message path, this message is
                # not added to the session history - confirm that is intended.
                message = result.get("message")
                if message:
                    await websocket.send_json({
                        "type": "assistant_message",
                        "content": message,
                        "phase": result.get("phase", ""),
                    })

                await _send_data_if_present(websocket, result, "sphere_data", "sphere_update")
                await _send_data_if_present(websocket, result, "intervention", "intervention")
                await _send_data_if_present(websocket, result, "counterfactual", "counterfactual")

            elif msg_type == "set_empathy":
                # float() raises on non-numeric strings (ValueError) and on
                # null/containers (TypeError); answer with an error instead
                # of dropping the connection.
                try:
                    value = float(data.get("value", 0.5))
                except (TypeError, ValueError):
                    await _send_error(websocket, "set_empathy requires a numeric value")
                    continue

                agent.set_empathy_dial(value)
                await websocket.send_json({
                    "type": "empathy_updated",
                    "value": value,
                })

            # Any other message type is ignored.
    except WebSocketDisconnect:
        # Client went away; nothing to clean up here.
        pass