Complete Step 2: WebSocket Conversation Bridge
- Implement ConversationService for managing active AI conversations
- Add REST API endpoints for conversation control (/api/conversations/*)
- Integrate ConversationManager with WebSocket broadcasting
- Create comprehensive test script for end-to-end validation
- Update WebSocket client for Gradio compatibility (remove Streamlit deps)
- Switch frontend framework from Streamlit to Gradio in architecture
- Fix Python 3.9+ type hints compatibility
- Add prominent Quick Demo section to PROJECT_STATE.md
- Successfully tested full 3-terminal pipeline (Ollama + FastAPI + WebSocket)
Real-time AI-to-AI conversation streaming now fully operational.
backend/api/main.py
CHANGED
|
@@ -14,9 +14,16 @@ Typical usage:
|
|
| 14 |
from fastapi import FastAPI, WebSocket
|
| 15 |
from fastapi.middleware.cors import CORSMiddleware
|
| 16 |
import uvicorn
|
|
|
|
| 17 |
|
| 18 |
-
# Import WebSocket endpoint
|
| 19 |
-
from .websockets.conversation_ws import websocket_endpoint
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# Initialize FastAPI app
|
| 22 |
app = FastAPI(
|
|
@@ -28,12 +35,30 @@ app = FastAPI(
|
|
| 28 |
# Configure CORS
|
| 29 |
app.add_middleware(
|
| 30 |
CORSMiddleware,
|
| 31 |
-
allow_origins=[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
allow_credentials=True,
|
| 33 |
allow_methods=["*"],
|
| 34 |
allow_headers=["*"],
|
| 35 |
)
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
@app.get("/")
|
| 39 |
async def root():
|
|
|
|
| 14 |
from fastapi import FastAPI, WebSocket
|
| 15 |
from fastapi.middleware.cors import CORSMiddleware
|
| 16 |
import uvicorn
|
| 17 |
+
import logging
|
| 18 |
|
| 19 |
+
# Import WebSocket endpoint and manager
|
| 20 |
+
from .websockets.conversation_ws import websocket_endpoint, manager
|
| 21 |
+
from .routes.conversations import router as conversations_router
|
| 22 |
+
from .services.conversation_service import initialize_conversation_service
|
| 23 |
+
|
| 24 |
+
# Setup logging
|
| 25 |
+
logging.basicConfig(level=logging.INFO)
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
|
| 28 |
# Initialize FastAPI app
|
| 29 |
app = FastAPI(
|
|
|
|
| 35 |
# Configure CORS
|
| 36 |
app.add_middleware(
|
| 37 |
CORSMiddleware,
|
| 38 |
+
allow_origins=[
|
| 39 |
+
"http://localhost:8501", # Streamlit (legacy)
|
| 40 |
+
"http://localhost:7860", # Gradio default port
|
| 41 |
+
"http://127.0.0.1:7860", # Gradio alternative
|
| 42 |
+
],
|
| 43 |
allow_credentials=True,
|
| 44 |
allow_methods=["*"],
|
| 45 |
allow_headers=["*"],
|
| 46 |
)
|
| 47 |
|
| 48 |
+
# Include API routes
|
| 49 |
+
app.include_router(conversations_router)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@app.on_event("startup")
async def startup_event():
    """Initialize services on startup.

    Wires the global ConversationService to the WebSocket connection
    manager so REST-triggered conversations can broadcast to clients.
    """
    # NOTE(review): @app.on_event is deprecated in newer FastAPI releases in
    # favor of lifespan handlers — confirm the pinned FastAPI version.
    logger.info("Initializing AI Survey Simulator API...")

    # Initialize conversation service with WebSocket manager
    initialize_conversation_service(manager)

    logger.info("API startup complete")
|
| 61 |
+
|
| 62 |
|
| 63 |
@app.get("/")
|
| 64 |
async def root():
|
backend/api/routes/conversations.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""REST API routes for conversation management.
|
| 2 |
+
|
| 3 |
+
This module provides HTTP endpoints for controlling AI-to-AI conversations,
|
| 4 |
+
managing personas, and retrieving conversation status information.
|
| 5 |
+
|
| 6 |
+
Routes:
|
| 7 |
+
POST /api/conversations/start - Start a new conversation
|
| 8 |
+
GET /api/conversations/{conversation_id}/status - Get conversation status
|
| 9 |
+
POST /api/conversations/{conversation_id}/stop - Stop a conversation
|
| 10 |
+
GET /api/conversations - List active conversations
|
| 11 |
+
GET /api/personas - List available personas
|
| 12 |
+
|
| 13 |
+
Example:
|
| 14 |
+
POST /api/conversations/start
|
| 15 |
+
{
|
| 16 |
+
"conversation_id": "conv_123",
|
| 17 |
+
"surveyor_persona_id": "surveyor_001",
|
| 18 |
+
"patient_persona_id": "patient_001"
|
| 19 |
+
}
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
| 23 |
+
from pydantic import BaseModel, Field
|
| 24 |
+
from typing import Dict, List, Optional
|
| 25 |
+
import logging
|
| 26 |
+
import sys
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
|
| 29 |
+
# Add backend to path for imports
|
| 30 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 31 |
+
|
| 32 |
+
from api.services.conversation_service import get_conversation_service
|
| 33 |
+
from core.persona_system import PersonaSystem
|
| 34 |
+
|
| 35 |
+
# Setup logging
|
| 36 |
+
logger = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
# Create router
|
| 39 |
+
router = APIRouter(prefix="/api", tags=["conversations"])
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Request/Response models
class StartConversationRequest(BaseModel):
    """Request model for starting a conversation."""
    conversation_id: str = Field(..., description="Unique identifier for the conversation")
    surveyor_persona_id: str = Field(..., description="ID of the surveyor persona")
    patient_persona_id: str = Field(..., description="ID of the patient persona")
    # Defaults target a local Ollama instance; callers may override per request.
    host: str = Field(default="http://localhost:11434", description="Ollama server host")
    model: str = Field(default="llama2:7b", description="LLM model to use")


class ConversationStatusResponse(BaseModel):
    """Response model for conversation status."""
    conversation_id: str
    status: str
    surveyor_persona_id: str
    patient_persona_id: str
    # ISO-8601 timestamp string (produced via datetime.isoformat() in the service).
    created_at: str
    message_count: int


class PersonaResponse(BaseModel):
    """Response model for persona information."""
    id: str
    name: str
    # Either "surveyor" or "patient" as set by the /personas route.
    role: str
    description: str


class ErrorResponse(BaseModel):
    """Error response model."""
    error: str
    detail: Optional[str] = None


# Initialize persona system
# NOTE(review): module-level instantiation runs at import time — confirm
# PersonaSystem() does no heavy I/O on construction.
persona_system = PersonaSystem()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@router.post("/conversations/start")
async def start_conversation(request: StartConversationRequest) -> Dict[str, str]:
    """Start a new AI-to-AI conversation.

    Args:
        request: Conversation start request with persona IDs

    Returns:
        Dict with success message and conversation_id

    Raises:
        HTTPException: 400 if the service refuses to start the conversation,
            500 on unexpected internal errors
    """
    try:
        service = get_conversation_service()

        success = await service.start_conversation(
            conversation_id=request.conversation_id,
            surveyor_persona_id=request.surveyor_persona_id,
            patient_persona_id=request.patient_persona_id,
            host=request.host,
            model=request.model
        )

        if success:
            logger.info(f"Started conversation {request.conversation_id}")
            return {
                "message": "Conversation started successfully",
                "conversation_id": request.conversation_id
            }
        else:
            raise HTTPException(
                status_code=400,
                detail="Failed to start conversation"
            )

    except HTTPException:
        # Re-raise untouched. Without this clause the broad handler below
        # caught the intended 400 above and converted it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error starting conversation: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Internal error starting conversation: {str(e)}"
        )
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@router.get("/conversations/{conversation_id}/status")
async def get_conversation_status(conversation_id: str) -> ConversationStatusResponse:
    """Return the current status of one conversation.

    Args:
        conversation_id: Unique identifier of the conversation

    Returns:
        ConversationStatusResponse describing the conversation

    Raises:
        HTTPException: 404 when the conversation is unknown,
            500 on unexpected internal failures
    """
    try:
        svc = get_conversation_service()
        info = await svc.get_conversation_status(conversation_id)

        if info is None:
            raise HTTPException(
                status_code=404,
                detail=f"Conversation {conversation_id} not found"
            )

        return ConversationStatusResponse(**info)

    except HTTPException:
        # Deliberate 404 above must not be rewrapped as a 500.
        raise
    except Exception as e:
        logger.error(f"Error getting conversation status: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Internal error retrieving status: {str(e)}"
        )
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
@router.post("/conversations/{conversation_id}/stop")
async def stop_conversation(conversation_id: str) -> Dict[str, str]:
    """Stop an active conversation.

    Args:
        conversation_id: Unique identifier of the conversation

    Returns:
        Dict with success message

    Raises:
        HTTPException: 404 when the conversation is unknown,
            500 on unexpected internal failures
    """
    try:
        svc = get_conversation_service()
        stopped = await svc.stop_conversation(conversation_id)

        # Guard clause: unknown or already-gone conversations are a 404.
        if not stopped:
            raise HTTPException(
                status_code=404,
                detail=f"Conversation {conversation_id} not found"
            )

        logger.info(f"Stopped conversation {conversation_id}")
        return {
            "message": "Conversation stopped successfully",
            "conversation_id": conversation_id
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error stopping conversation: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Internal error stopping conversation: {str(e)}"
        )
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@router.get("/conversations")
async def list_conversations() -> Dict[str, Dict]:
    """List all active conversations.

    Returns:
        Dict mapping conversation_id to status information
    """
    try:
        # Delegate directly; the service already returns the wire format.
        return await get_conversation_service().list_active_conversations()
    except Exception as e:
        logger.error(f"Error listing conversations: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Internal error listing conversations: {str(e)}"
        )
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def _persona_to_response(persona: Dict, role: str) -> PersonaResponse:
    """Convert a raw persona dict from PersonaSystem into a PersonaResponse.

    Missing fields fall back to safe defaults rather than raising.
    """
    return PersonaResponse(
        id=persona.get("id", ""),
        name=persona.get("name", "Unknown"),
        role=role,
        description=persona.get("description", "")
    )


@router.get("/personas")
async def list_personas() -> Dict[str, List[PersonaResponse]]:
    """List all available personas grouped by role.

    Returns:
        Dict with 'surveyors' and 'patients' lists of personas

    Raises:
        HTTPException: 500 if the persona system fails
    """
    try:
        # The two branches were previously copy-pasted; the shared
        # conversion now lives in _persona_to_response.
        surveyors = persona_system.list_personas("surveyor")
        patients = persona_system.list_personas("patient")

        return {
            "surveyors": [_persona_to_response(p, "surveyor") for p in surveyors],
            "patients": [_persona_to_response(p, "patient") for p in patients],
        }

    except Exception as e:
        logger.error(f"Error listing personas: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Internal error listing personas: {str(e)}"
        )
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
@router.get("/health")
async def health_check() -> Dict[str, str]:
    """Health check endpoint for monitoring.

    Returns:
        Dict with service health status
    """
    try:
        # Raises RuntimeError if the service was never initialized —
        # the call itself is the health probe.
        get_conversation_service()

        # Probe the persona system as well.
        loaded = persona_system.list_personas("surveyor")

        return {
            "status": "healthy",
            "service": "conversation_api",
            "personas_loaded": str(len(loaded))
        }

    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return {
            "status": "unhealthy",
            "error": str(e)
        }
|
backend/api/services/conversation_service.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Conversation Service for managing active AI-to-AI conversations.
|
| 2 |
+
|
| 3 |
+
This service acts as the bridge between the WebSocket interface and the
|
| 4 |
+
ConversationManager. It handles the lifecycle of conversations, manages
|
| 5 |
+
active instances, and coordinates message streaming to connected clients.
|
| 6 |
+
|
| 7 |
+
Classes:
|
| 8 |
+
ConversationService: Main service for conversation management
|
| 9 |
+
ConversationInfo: Data class for conversation metadata
|
| 10 |
+
|
| 11 |
+
Example:
|
| 12 |
+
service = ConversationService(websocket_manager)
|
| 13 |
+
conversation_id = await service.start_conversation(
|
| 14 |
+
surveyor_id="surveyor_001",
|
| 15 |
+
patient_id="patient_001"
|
| 16 |
+
)
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import asyncio
|
| 20 |
+
import logging
|
| 21 |
+
from datetime import datetime
|
| 22 |
+
from typing import Dict, Optional, Set
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
from enum import Enum
|
| 25 |
+
import sys
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
# Add backend to path for imports
|
| 29 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 30 |
+
|
| 31 |
+
from core.conversation_manager import ConversationManager, ConversationState
|
| 32 |
+
from core.persona_system import PersonaSystem
|
| 33 |
+
from api.websockets.conversation_ws import ConnectionManager
|
| 34 |
+
|
| 35 |
+
# Setup logging
|
| 36 |
+
logger = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ConversationStatus(Enum):
    """Status of managed conversations.

    Values are lowercase strings so they serialize directly into
    WebSocket/REST payloads via ``.value``.
    """
    STARTING = "starting"
    RUNNING = "running"
    PAUSED = "paused"
    STOPPING = "stopping"
    COMPLETED = "completed"
    ERROR = "error"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
class ConversationInfo:
    """Information about an active conversation."""
    conversation_id: str
    surveyor_persona_id: str
    patient_persona_id: str
    status: ConversationStatus
    created_at: datetime
    # Number of messages streamed to clients so far.
    message_count: int = 0
    # Background task running _stream_conversation; None until started.
    task: Optional[asyncio.Task] = None
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class ConversationService:
    """Service for managing AI-to-AI conversation instances.

    This service coordinates between the ConversationManager and WebSocket
    infrastructure to provide real-time conversation streaming to web clients.

    Attributes:
        websocket_manager: WebSocket connection manager for broadcasting
        persona_system: Persona system for loading personas
        active_conversations: Dict of active conversation instances
    """

    def __init__(self, websocket_manager: ConnectionManager):
        """Initialize conversation service.

        Args:
            websocket_manager: WebSocket manager for message broadcasting
        """
        self.websocket_manager = websocket_manager
        self.persona_system = PersonaSystem()
        # Maps conversation_id -> ConversationInfo for every live conversation.
        self.active_conversations: Dict[str, ConversationInfo] = {}

    async def start_conversation(self,
                                 conversation_id: str,
                                 surveyor_persona_id: str,
                                 patient_persona_id: str,
                                 host: str = "http://localhost:11434",
                                 model: str = "llama2:7b") -> bool:
        """Start a new AI-to-AI conversation.

        Args:
            conversation_id: Unique identifier for the conversation
            surveyor_persona_id: ID of the surveyor persona
            patient_persona_id: ID of the patient persona
            host: Ollama server host
            model: LLM model to use

        Returns:
            True if conversation started successfully
        """
        if conversation_id in self.active_conversations:
            logger.warning(f"Conversation {conversation_id} already exists")
            return False

        try:
            # Resolve persona IDs against the persona system.
            surveyors = self.persona_system.list_personas("surveyor")
            patients = self.persona_system.list_personas("patient")

            surveyor_persona = next((p for p in surveyors if p.get("id") == surveyor_persona_id), None)
            patient_persona = next((p for p in patients if p.get("id") == patient_persona_id), None)

            if not surveyor_persona or not patient_persona:
                await self._send_error(conversation_id, "Invalid persona IDs")
                return False

            # Register before launching the task so status queries see it
            # immediately.
            conv_info = ConversationInfo(
                conversation_id=conversation_id,
                surveyor_persona_id=surveyor_persona_id,
                patient_persona_id=patient_persona_id,
                status=ConversationStatus.STARTING,
                created_at=datetime.now()
            )
            self.active_conversations[conversation_id] = conv_info

            await self._send_status_update(conversation_id, ConversationStatus.STARTING)

            manager = ConversationManager(
                surveyor_persona=surveyor_persona,
                patient_persona=patient_persona,
                host=host,
                model=model
            )

            # Stream in the background; the task owns the manager's lifecycle.
            conv_info.task = asyncio.create_task(
                self._stream_conversation(conversation_id, manager)
            )

            conv_info.status = ConversationStatus.RUNNING
            await self._send_status_update(conversation_id, ConversationStatus.RUNNING)

            logger.info(f"Started conversation {conversation_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to start conversation {conversation_id}: {e}")
            await self._send_error(conversation_id, f"Failed to start conversation: {str(e)}")

            # Roll back the registration so a retry with the same ID works.
            if conversation_id in self.active_conversations:
                del self.active_conversations[conversation_id]

            return False

    async def stop_conversation(self, conversation_id: str) -> bool:
        """Stop an active conversation.

        Args:
            conversation_id: ID of conversation to stop

        Returns:
            True if conversation stopped successfully
        """
        if conversation_id not in self.active_conversations:
            logger.warning(f"Conversation {conversation_id} not found")
            return False

        conv_info = self.active_conversations[conversation_id]

        try:
            conv_info.status = ConversationStatus.STOPPING
            await self._send_status_update(conversation_id, ConversationStatus.STOPPING)

            # Cancel the streaming task and wait for it to unwind.
            if conv_info.task and not conv_info.task.done():
                conv_info.task.cancel()
                try:
                    await conv_info.task
                except asyncio.CancelledError:
                    pass

            # This is the single terminal broadcast for the stop path;
            # _stream_conversation skips its own COMPLETED broadcast when
            # the status is already STOPPING.
            conv_info.status = ConversationStatus.COMPLETED
            await self._send_status_update(conversation_id, ConversationStatus.COMPLETED)

            del self.active_conversations[conversation_id]
            logger.info(f"Stopped conversation {conversation_id}")
            return True

        except Exception as e:
            logger.error(f"Error stopping conversation {conversation_id}: {e}")
            conv_info.status = ConversationStatus.ERROR
            await self._send_error(conversation_id, f"Error stopping conversation: {str(e)}")
            return False

    async def get_conversation_status(self, conversation_id: str) -> Optional[Dict]:
        """Get status of a conversation.

        Args:
            conversation_id: ID of conversation

        Returns:
            Dict with conversation status or None if not found
        """
        if conversation_id not in self.active_conversations:
            return None

        conv_info = self.active_conversations[conversation_id]
        return {
            "conversation_id": conversation_id,
            "status": conv_info.status.value,
            "surveyor_persona_id": conv_info.surveyor_persona_id,
            "patient_persona_id": conv_info.patient_persona_id,
            "created_at": conv_info.created_at.isoformat(),
            "message_count": conv_info.message_count
        }

    async def list_active_conversations(self) -> Dict[str, Dict]:
        """List all active conversations.

        Returns:
            Dict mapping conversation_id to status info
        """
        result = {}
        for conv_id, conv_info in self.active_conversations.items():
            result[conv_id] = {
                "status": conv_info.status.value,
                "surveyor_persona_id": conv_info.surveyor_persona_id,
                "patient_persona_id": conv_info.patient_persona_id,
                "created_at": conv_info.created_at.isoformat(),
                "message_count": conv_info.message_count
            }
        return result

    async def _stream_conversation(self, conversation_id: str, manager: ConversationManager):
        """Stream conversation messages to WebSocket clients.

        Args:
            conversation_id: ID of the conversation
            manager: ConversationManager instance to stream from
        """
        conv_info = self.active_conversations.get(conversation_id)
        if not conv_info:
            return

        try:
            async for message in manager.conduct_conversation():
                # Stop streaming if the conversation was removed externally.
                if conversation_id not in self.active_conversations:
                    break

                conv_info.message_count += 1

                # Wrap the manager's message with routing metadata.
                websocket_message = {
                    "type": "conversation_message",
                    "conversation_id": conversation_id,
                    **message
                }

                await self.websocket_manager.send_to_conversation(
                    conversation_id, websocket_message
                )

                logger.debug(f"Streamed message {conv_info.message_count} for conversation {conversation_id}")

        except asyncio.CancelledError:
            logger.info(f"Conversation {conversation_id} streaming cancelled")
            raise
        except Exception as e:
            logger.error(f"Error streaming conversation {conversation_id}: {e}")
            conv_info.status = ConversationStatus.ERROR
            await self._send_error(conversation_id, f"Streaming error: {str(e)}")
        finally:
            # Best-effort close. The previous bare `except:` swallowed even
            # CancelledError/KeyboardInterrupt; narrow it to Exception.
            try:
                await manager.close()
            except Exception:
                pass

            # Broadcast COMPLETED only for a natural end of the stream.
            # Previously this also fired on the stop path (status STOPPING),
            # duplicating stop_conversation()'s own terminal broadcast.
            if conv_info.status in (ConversationStatus.STARTING, ConversationStatus.RUNNING):
                conv_info.status = ConversationStatus.COMPLETED
                await self._send_status_update(conversation_id, ConversationStatus.COMPLETED)

    async def _send_status_update(self, conversation_id: str, status: ConversationStatus):
        """Send conversation status update to clients.

        Args:
            conversation_id: ID of the conversation
            status: New conversation status
        """
        message = {
            "type": "conversation_status",
            "conversation_id": conversation_id,
            "status": status.value,
            "timestamp": datetime.now().isoformat()
        }

        await self.websocket_manager.send_to_conversation(conversation_id, message)

    async def _send_error(self, conversation_id: str, error_message: str):
        """Send error message to clients.

        Args:
            conversation_id: ID of the conversation
            error_message: Error description
        """
        message = {
            "type": "conversation_error",
            "conversation_id": conversation_id,
            "error": error_message,
            "timestamp": datetime.now().isoformat()
        }

        await self.websocket_manager.send_to_conversation(conversation_id, message)

    async def cleanup(self):
        """Clean up all active conversations."""
        # Iterate over a snapshot: stop_conversation mutates the dict.
        for conversation_id in list(self.active_conversations.keys()):
            await self.stop_conversation(conversation_id)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# Global service instance (initialized in main.py)
conversation_service: Optional[ConversationService] = None


def get_conversation_service() -> ConversationService:
    """Return the process-wide ConversationService.

    Returns:
        ConversationService instance

    Raises:
        RuntimeError: If service not initialized
    """
    if conversation_service is not None:
        return conversation_service
    raise RuntimeError("ConversationService not initialized")


def initialize_conversation_service(websocket_manager: ConnectionManager):
    """Create and install the global conversation service.

    Args:
        websocket_manager: WebSocket connection manager
    """
    global conversation_service
    conversation_service = ConversationService(websocket_manager)
    logger.info("ConversationService initialized")
|
backend/api/websockets/conversation_ws.py
CHANGED
|
@@ -174,9 +174,10 @@ def validate_message(data: dict) -> bool:
|
|
| 174 |
# Validate message types
|
| 175 |
valid_types = [
|
| 176 |
"conversation_message",
|
| 177 |
-
"typing_indicator",
|
| 178 |
"conversation_control",
|
| 179 |
-
"heartbeat"
|
|
|
|
| 180 |
]
|
| 181 |
|
| 182 |
if data["type"] not in valid_types:
|
|
@@ -206,7 +207,11 @@ async def handle_message(data: dict, conversation_id: str):
|
|
| 206 |
elif message_type == "conversation_control":
|
| 207 |
# Handle conversation control (start, pause, stop)
|
| 208 |
await handle_conversation_control(data, conversation_id)
|
| 209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
elif message_type == "heartbeat":
|
| 211 |
# Respond to heartbeat
|
| 212 |
await manager.send_to_conversation(conversation_id, {
|
|
@@ -217,26 +222,114 @@ async def handle_message(data: dict, conversation_id: str):
|
|
| 217 |
|
| 218 |
async def handle_conversation_control(data: dict, conversation_id: str):
|
| 219 |
"""Handle conversation control messages.
|
| 220 |
-
|
| 221 |
Args:
|
| 222 |
data: Control message data
|
| 223 |
conversation_id: Target conversation ID
|
| 224 |
"""
|
| 225 |
control_action = data.get("action")
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
await manager.send_to_conversation(conversation_id, {
|
| 230 |
-
"type": "
|
| 231 |
-
"
|
| 232 |
-
"
|
| 233 |
-
|
| 234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
})
|
| 236 |
-
|
| 237 |
-
logger.info(f"Conversation {conversation_id} {control_action}ed")
|
| 238 |
-
else:
|
| 239 |
-
logger.warning(f"Unknown control action: {control_action}")
|
| 240 |
|
| 241 |
|
| 242 |
# Export the manager for use in other modules
|
|
|
|
| 174 |
# Validate message types
|
| 175 |
valid_types = [
|
| 176 |
"conversation_message",
|
| 177 |
+
"typing_indicator",
|
| 178 |
"conversation_control",
|
| 179 |
+
"heartbeat",
|
| 180 |
+
"start_conversation" # New message type
|
| 181 |
]
|
| 182 |
|
| 183 |
if data["type"] not in valid_types:
|
|
|
|
| 207 |
elif message_type == "conversation_control":
|
| 208 |
# Handle conversation control (start, pause, stop)
|
| 209 |
await handle_conversation_control(data, conversation_id)
|
| 210 |
+
|
| 211 |
+
elif message_type == "start_conversation":
|
| 212 |
+
# Handle starting a new conversation
|
| 213 |
+
await handle_start_conversation(data, conversation_id)
|
| 214 |
+
|
| 215 |
elif message_type == "heartbeat":
|
| 216 |
# Respond to heartbeat
|
| 217 |
await manager.send_to_conversation(conversation_id, {
|
|
|
|
| 222 |
|
| 223 |
async def handle_conversation_control(data: dict, conversation_id: str):
    """Dispatch a conversation control message (stop / pause / resume).

    Args:
        data: Control message data; the "action" key selects the behavior.
        conversation_id: Target conversation ID
    """
    control_action = data.get("action")

    try:
        # Deferred import: the service module imports from this one, so a
        # top-level import would be circular.
        from ..services.conversation_service import get_conversation_service
        service = get_conversation_service()

        broadcast = manager.send_to_conversation

        if control_action == "stop":
            if await service.stop_conversation(conversation_id):
                await broadcast(conversation_id, {
                    "type": "conversation_control",
                    "action": "stop",
                    "conversation_id": conversation_id,
                    "timestamp": datetime.now().isoformat(),
                    "message": "Conversation stopped",
                })
            else:
                await broadcast(conversation_id, {
                    "type": "error",
                    "error": "Failed to stop conversation",
                    "timestamp": datetime.now().isoformat(),
                })

        elif control_action in ("pause", "resume"):
            # Pause/resume are not fully implemented yet; for now the action
            # is simply echoed back to everyone in the conversation.
            await broadcast(conversation_id, {
                "type": "conversation_control",
                "action": control_action,
                "conversation_id": conversation_id,
                "timestamp": datetime.now().isoformat(),
                "message": f"Conversation {control_action}d",
            })
            logger.info(f"Conversation {conversation_id} {control_action}d")

        else:
            logger.warning(f"Unknown control action: {control_action}")
            await broadcast(conversation_id, {
                "type": "error",
                "error": f"Unknown control action: {control_action}",
                "timestamp": datetime.now().isoformat(),
            })

    except Exception as e:
        logger.error(f"Error handling conversation control: {e}")
        await manager.send_to_conversation(conversation_id, {
            "type": "error",
            "error": f"Control error: {str(e)}",
            "timestamp": datetime.now().isoformat(),
        })
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
async def handle_start_conversation(data: dict, conversation_id: str):
    """Kick off a new AI-to-AI conversation requested over the WebSocket.

    Args:
        data: Start message payload; must carry both persona ids and may
            override the Ollama host and model defaults.
        conversation_id: Target conversation ID
    """
    async def report_error(text: str) -> None:
        # Every failure path answers the client with the same error envelope.
        await manager.send_to_conversation(conversation_id, {
            "type": "error",
            "error": text,
            "timestamp": datetime.now().isoformat(),
        })

    try:
        # Deferred import: avoids a circular dependency with the service module.
        from ..services.conversation_service import get_conversation_service
        service = get_conversation_service()

        surveyor_persona_id = data.get("surveyor_persona_id")
        patient_persona_id = data.get("patient_persona_id")
        host = data.get("host", "http://localhost:11434")
        model = data.get("model", "llama2:7b")

        # Both personas are mandatory; bail out early if either is missing.
        if not (surveyor_persona_id and patient_persona_id):
            await report_error("Missing required persona IDs")
            return

        started = await service.start_conversation(
            conversation_id=conversation_id,
            surveyor_persona_id=surveyor_persona_id,
            patient_persona_id=patient_persona_id,
            host=host,
            model=model,
        )

        if started:
            logger.info(f"Started conversation {conversation_id} via WebSocket")
        else:
            await report_error("Failed to start conversation")

    except Exception as e:
        logger.error(f"Error starting conversation via WebSocket: {e}")
        await report_error(f"Start error: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
|
| 334 |
|
| 335 |
# Export the manager for use in other modules
|
backend/core/persona_system.py
CHANGED
|
@@ -10,7 +10,7 @@ both survey interviewers and patient respondents. It provides:
|
|
| 10 |
Classes:
|
| 11 |
PersonaSystem: Main persona management system
|
| 12 |
PersonaPromptBuilder: Constructs prompts with persona context
|
| 13 |
-
|
| 14 |
Example:
|
| 15 |
system = PersonaSystem()
|
| 16 |
prompt = system.build_conversation_prompt(
|
|
@@ -20,8 +20,10 @@ Example:
|
|
| 20 |
)
|
| 21 |
"""
|
| 22 |
|
|
|
|
|
|
|
| 23 |
import yaml
|
| 24 |
-
from typing import Dict, List, Optional, Any
|
| 25 |
from pathlib import Path
|
| 26 |
from datetime import datetime
|
| 27 |
import logging
|
|
@@ -218,7 +220,7 @@ class PersonaSystem:
|
|
| 218 |
conversation_history: List[Dict[str, Any]] = None,
|
| 219 |
current_context: Dict[str, Any] = None,
|
| 220 |
user_prompt: str = "",
|
| 221 |
-
max_history: int = 10) ->
|
| 222 |
"""Build complete prompt for conversation generation.
|
| 223 |
|
| 224 |
Args:
|
|
|
|
| 10 |
Classes:
|
| 11 |
PersonaSystem: Main persona management system
|
| 12 |
PersonaPromptBuilder: Constructs prompts with persona context
|
| 13 |
+
|
| 14 |
Example:
|
| 15 |
system = PersonaSystem()
|
| 16 |
prompt = system.build_conversation_prompt(
|
|
|
|
| 20 |
)
|
| 21 |
"""
|
| 22 |
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
import yaml
|
| 26 |
+
from typing import Dict, List, Optional, Any, Tuple
|
| 27 |
from pathlib import Path
|
| 28 |
from datetime import datetime
|
| 29 |
import logging
|
|
|
|
| 220 |
conversation_history: List[Dict[str, Any]] = None,
|
| 221 |
current_context: Dict[str, Any] = None,
|
| 222 |
user_prompt: str = "",
|
| 223 |
+
max_history: int = 10) -> Tuple[str, str]:
|
| 224 |
"""Build complete prompt for conversation generation.
|
| 225 |
|
| 226 |
Args:
|