MikelWL commited on
Commit
693f75f
·
1 Parent(s): 221c154

Complete Step 2: WebSocket Conversation Bridge

Browse files

- Implement ConversationService for managing active AI conversations
- Add REST API endpoints for conversation control (/api/conversations/*)
- Integrate ConversationManager with WebSocket broadcasting
- Create comprehensive test script for end-to-end validation
- Update WebSocket client for Gradio compatibility (remove Streamlit deps)
- Switch frontend framework from Streamlit to Gradio in architecture
- Fix Python 3.9+ type hints compatibility
- Add prominent Quick Demo section to PROJECT_STATE.md
- Successfully tested full 3-terminal pipeline (Ollama + FastAPI + WebSocket)

Real-time AI-to-AI conversation streaming now fully operational.

backend/api/main.py CHANGED
@@ -14,9 +14,16 @@ Typical usage:
14
  from fastapi import FastAPI, WebSocket
15
  from fastapi.middleware.cors import CORSMiddleware
16
  import uvicorn
 
17
 
18
- # Import WebSocket endpoint
19
- from .websockets.conversation_ws import websocket_endpoint
 
 
 
 
 
 
20
 
21
  # Initialize FastAPI app
22
  app = FastAPI(
@@ -28,12 +35,30 @@ app = FastAPI(
28
  # Configure CORS
29
  app.add_middleware(
30
  CORSMiddleware,
31
- allow_origins=["http://localhost:8501"], # Streamlit frontend
 
 
 
 
32
  allow_credentials=True,
33
  allow_methods=["*"],
34
  allow_headers=["*"],
35
  )
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  @app.get("/")
39
  async def root():
 
14
  from fastapi import FastAPI, WebSocket
15
  from fastapi.middleware.cors import CORSMiddleware
16
  import uvicorn
17
+ import logging
18
 
19
+ # Import WebSocket endpoint and manager
20
+ from .websockets.conversation_ws import websocket_endpoint, manager
21
+ from .routes.conversations import router as conversations_router
22
+ from .services.conversation_service import initialize_conversation_service
23
+
24
+ # Setup logging
25
+ logging.basicConfig(level=logging.INFO)
26
+ logger = logging.getLogger(__name__)
27
 
28
  # Initialize FastAPI app
29
  app = FastAPI(
 
35
  # Configure CORS
36
  app.add_middleware(
37
  CORSMiddleware,
38
+ allow_origins=[
39
+ "http://localhost:8501", # Streamlit (legacy)
40
+ "http://localhost:7860", # Gradio default port
41
+ "http://127.0.0.1:7860", # Gradio alternative
42
+ ],
43
  allow_credentials=True,
44
  allow_methods=["*"],
45
  allow_headers=["*"],
46
  )
47
 
48
+ # Include API routes
49
+ app.include_router(conversations_router)
50
+
51
+
52
+ @app.on_event("startup")
53
+ async def startup_event():
54
+ """Initialize services on startup."""
55
+ logger.info("Initializing AI Survey Simulator API...")
56
+
57
+ # Initialize conversation service with WebSocket manager
58
+ initialize_conversation_service(manager)
59
+
60
+ logger.info("API startup complete")
61
+
62
 
63
  @app.get("/")
64
  async def root():
backend/api/routes/conversations.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """REST API routes for conversation management.
2
+
3
+ This module provides HTTP endpoints for controlling AI-to-AI conversations,
4
+ managing personas, and retrieving conversation status information.
5
+
6
+ Routes:
7
+ POST /api/conversations/start - Start a new conversation
8
+ GET /api/conversations/{conversation_id}/status - Get conversation status
9
+ POST /api/conversations/{conversation_id}/stop - Stop a conversation
10
+ GET /api/conversations - List active conversations
11
+ GET /api/personas - List available personas
12
+
13
+ Example:
14
+ POST /api/conversations/start
15
+ {
16
+ "conversation_id": "conv_123",
17
+ "surveyor_persona_id": "surveyor_001",
18
+ "patient_persona_id": "patient_001"
19
+ }
20
+ """
21
+
22
+ from fastapi import APIRouter, HTTPException, BackgroundTasks
23
+ from pydantic import BaseModel, Field
24
+ from typing import Dict, List, Optional
25
+ import logging
26
+ import sys
27
+ from pathlib import Path
28
+
29
+ # Add backend to path for imports
30
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
31
+
32
+ from api.services.conversation_service import get_conversation_service
33
+ from core.persona_system import PersonaSystem
34
+
35
+ # Setup logging
36
+ logger = logging.getLogger(__name__)
37
+
38
+ # Create router
39
+ router = APIRouter(prefix="/api", tags=["conversations"])
40
+
41
+
42
+ # Request/Response models
43
+ class StartConversationRequest(BaseModel):
44
+ """Request model for starting a conversation."""
45
+ conversation_id: str = Field(..., description="Unique identifier for the conversation")
46
+ surveyor_persona_id: str = Field(..., description="ID of the surveyor persona")
47
+ patient_persona_id: str = Field(..., description="ID of the patient persona")
48
+ host: str = Field(default="http://localhost:11434", description="Ollama server host")
49
+ model: str = Field(default="llama2:7b", description="LLM model to use")
50
+
51
+
52
+ class ConversationStatusResponse(BaseModel):
53
+ """Response model for conversation status."""
54
+ conversation_id: str
55
+ status: str
56
+ surveyor_persona_id: str
57
+ patient_persona_id: str
58
+ created_at: str
59
+ message_count: int
60
+
61
+
62
+ class PersonaResponse(BaseModel):
63
+ """Response model for persona information."""
64
+ id: str
65
+ name: str
66
+ role: str
67
+ description: str
68
+
69
+
70
+ class ErrorResponse(BaseModel):
71
+ """Error response model."""
72
+ error: str
73
+ detail: Optional[str] = None
74
+
75
+
76
+ # Initialize persona system
77
+ persona_system = PersonaSystem()
78
+
79
+
80
+ @router.post("/conversations/start")
81
+ async def start_conversation(request: StartConversationRequest) -> Dict[str, str]:
82
+ """Start a new AI-to-AI conversation.
83
+
84
+ Args:
85
+ request: Conversation start request with persona IDs
86
+
87
+ Returns:
88
+ Dict with success message and conversation_id
89
+
90
+ Raises:
91
+ HTTPException: If conversation fails to start
92
+ """
93
+ try:
94
+ service = get_conversation_service()
95
+
96
+ success = await service.start_conversation(
97
+ conversation_id=request.conversation_id,
98
+ surveyor_persona_id=request.surveyor_persona_id,
99
+ patient_persona_id=request.patient_persona_id,
100
+ host=request.host,
101
+ model=request.model
102
+ )
103
+
104
+ if success:
105
+ logger.info(f"Started conversation {request.conversation_id}")
106
+ return {
107
+ "message": "Conversation started successfully",
108
+ "conversation_id": request.conversation_id
109
+ }
110
+ else:
111
+ raise HTTPException(
112
+ status_code=400,
113
+ detail="Failed to start conversation"
114
+ )
115
+
116
+ except Exception as e:
117
+ logger.error(f"Error starting conversation: {e}")
118
+ raise HTTPException(
119
+ status_code=500,
120
+ detail=f"Internal error starting conversation: {str(e)}"
121
+ )
122
+
123
+
124
+ @router.get("/conversations/{conversation_id}/status")
125
+ async def get_conversation_status(conversation_id: str) -> ConversationStatusResponse:
126
+ """Get the status of a specific conversation.
127
+
128
+ Args:
129
+ conversation_id: Unique identifier of the conversation
130
+
131
+ Returns:
132
+ ConversationStatusResponse with current status
133
+
134
+ Raises:
135
+ HTTPException: If conversation not found
136
+ """
137
+ try:
138
+ service = get_conversation_service()
139
+ status = await service.get_conversation_status(conversation_id)
140
+
141
+ if status is None:
142
+ raise HTTPException(
143
+ status_code=404,
144
+ detail=f"Conversation {conversation_id} not found"
145
+ )
146
+
147
+ return ConversationStatusResponse(**status)
148
+
149
+ except HTTPException:
150
+ raise
151
+ except Exception as e:
152
+ logger.error(f"Error getting conversation status: {e}")
153
+ raise HTTPException(
154
+ status_code=500,
155
+ detail=f"Internal error retrieving status: {str(e)}"
156
+ )
157
+
158
+
159
+ @router.post("/conversations/{conversation_id}/stop")
160
+ async def stop_conversation(conversation_id: str) -> Dict[str, str]:
161
+ """Stop an active conversation.
162
+
163
+ Args:
164
+ conversation_id: Unique identifier of the conversation
165
+
166
+ Returns:
167
+ Dict with success message
168
+
169
+ Raises:
170
+ HTTPException: If conversation not found or cannot be stopped
171
+ """
172
+ try:
173
+ service = get_conversation_service()
174
+ success = await service.stop_conversation(conversation_id)
175
+
176
+ if success:
177
+ logger.info(f"Stopped conversation {conversation_id}")
178
+ return {
179
+ "message": "Conversation stopped successfully",
180
+ "conversation_id": conversation_id
181
+ }
182
+ else:
183
+ raise HTTPException(
184
+ status_code=404,
185
+ detail=f"Conversation {conversation_id} not found"
186
+ )
187
+
188
+ except HTTPException:
189
+ raise
190
+ except Exception as e:
191
+ logger.error(f"Error stopping conversation: {e}")
192
+ raise HTTPException(
193
+ status_code=500,
194
+ detail=f"Internal error stopping conversation: {str(e)}"
195
+ )
196
+
197
+
198
+ @router.get("/conversations")
199
+ async def list_conversations() -> Dict[str, Dict]:
200
+ """List all active conversations.
201
+
202
+ Returns:
203
+ Dict mapping conversation_id to status information
204
+ """
205
+ try:
206
+ service = get_conversation_service()
207
+ conversations = await service.list_active_conversations()
208
+ return conversations
209
+
210
+ except Exception as e:
211
+ logger.error(f"Error listing conversations: {e}")
212
+ raise HTTPException(
213
+ status_code=500,
214
+ detail=f"Internal error listing conversations: {str(e)}"
215
+ )
216
+
217
+
218
+ @router.get("/personas")
219
+ async def list_personas() -> Dict[str, List[PersonaResponse]]:
220
+ """List all available personas grouped by role.
221
+
222
+ Returns:
223
+ Dict with 'surveyors' and 'patients' lists of personas
224
+ """
225
+ try:
226
+ # Get all personas
227
+ surveyors = persona_system.list_personas("surveyor")
228
+ patients = persona_system.list_personas("patient")
229
+
230
+ # Format response
231
+ result = {
232
+ "surveyors": [
233
+ PersonaResponse(
234
+ id=persona.get("id", ""),
235
+ name=persona.get("name", "Unknown"),
236
+ role="surveyor",
237
+ description=persona.get("description", "")
238
+ )
239
+ for persona in surveyors
240
+ ],
241
+ "patients": [
242
+ PersonaResponse(
243
+ id=persona.get("id", ""),
244
+ name=persona.get("name", "Unknown"),
245
+ role="patient",
246
+ description=persona.get("description", "")
247
+ )
248
+ for persona in patients
249
+ ]
250
+ }
251
+
252
+ return result
253
+
254
+ except Exception as e:
255
+ logger.error(f"Error listing personas: {e}")
256
+ raise HTTPException(
257
+ status_code=500,
258
+ detail=f"Internal error listing personas: {str(e)}"
259
+ )
260
+
261
+
262
+ @router.get("/health")
263
+ async def health_check() -> Dict[str, str]:
264
+ """Health check endpoint for monitoring.
265
+
266
+ Returns:
267
+ Dict with service health status
268
+ """
269
+ try:
270
+ # Check if conversation service is available
271
+ service = get_conversation_service()
272
+
273
+ # Check if persona system is working
274
+ personas = persona_system.list_personas("surveyor")
275
+
276
+ return {
277
+ "status": "healthy",
278
+ "service": "conversation_api",
279
+ "personas_loaded": str(len(personas))
280
+ }
281
+
282
+ except Exception as e:
283
+ logger.error(f"Health check failed: {e}")
284
+ return {
285
+ "status": "unhealthy",
286
+ "error": str(e)
287
+ }
backend/api/services/conversation_service.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Conversation Service for managing active AI-to-AI conversations.
2
+
3
+ This service acts as the bridge between the WebSocket interface and the
4
+ ConversationManager. It handles the lifecycle of conversations, manages
5
+ active instances, and coordinates message streaming to connected clients.
6
+
7
+ Classes:
8
+ ConversationService: Main service for conversation management
9
+ ConversationInfo: Data class for conversation metadata
10
+
11
+ Example:
12
+ service = ConversationService(websocket_manager)
13
+ conversation_id = await service.start_conversation(
14
+ surveyor_id="surveyor_001",
15
+ patient_id="patient_001"
16
+ )
17
+ """
18
+
19
+ import asyncio
20
+ import logging
21
+ from datetime import datetime
22
+ from typing import Dict, Optional, Set
23
+ from dataclasses import dataclass
24
+ from enum import Enum
25
+ import sys
26
+ from pathlib import Path
27
+
28
+ # Add backend to path for imports
29
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
30
+
31
+ from core.conversation_manager import ConversationManager, ConversationState
32
+ from core.persona_system import PersonaSystem
33
+ from api.websockets.conversation_ws import ConnectionManager
34
+
35
+ # Setup logging
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ class ConversationStatus(Enum):
40
+ """Status of managed conversations."""
41
+ STARTING = "starting"
42
+ RUNNING = "running"
43
+ PAUSED = "paused"
44
+ STOPPING = "stopping"
45
+ COMPLETED = "completed"
46
+ ERROR = "error"
47
+
48
+
49
+ @dataclass
50
+ class ConversationInfo:
51
+ """Information about an active conversation."""
52
+ conversation_id: str
53
+ surveyor_persona_id: str
54
+ patient_persona_id: str
55
+ status: ConversationStatus
56
+ created_at: datetime
57
+ message_count: int = 0
58
+ task: Optional[asyncio.Task] = None
59
+
60
+
61
+ class ConversationService:
62
+ """Service for managing AI-to-AI conversation instances.
63
+
64
+ This service coordinates between the ConversationManager and WebSocket
65
+ infrastructure to provide real-time conversation streaming to web clients.
66
+
67
+ Attributes:
68
+ websocket_manager: WebSocket connection manager for broadcasting
69
+ persona_system: Persona system for loading personas
70
+ active_conversations: Dict of active conversation instances
71
+ """
72
+
73
+ def __init__(self, websocket_manager: ConnectionManager):
74
+ """Initialize conversation service.
75
+
76
+ Args:
77
+ websocket_manager: WebSocket manager for message broadcasting
78
+ """
79
+ self.websocket_manager = websocket_manager
80
+ self.persona_system = PersonaSystem()
81
+ self.active_conversations: Dict[str, ConversationInfo] = {}
82
+
83
+ async def start_conversation(self,
84
+ conversation_id: str,
85
+ surveyor_persona_id: str,
86
+ patient_persona_id: str,
87
+ host: str = "http://localhost:11434",
88
+ model: str = "llama2:7b") -> bool:
89
+ """Start a new AI-to-AI conversation.
90
+
91
+ Args:
92
+ conversation_id: Unique identifier for the conversation
93
+ surveyor_persona_id: ID of the surveyor persona
94
+ patient_persona_id: ID of the patient persona
95
+ host: Ollama server host
96
+ model: LLM model to use
97
+
98
+ Returns:
99
+ True if conversation started successfully
100
+ """
101
+ if conversation_id in self.active_conversations:
102
+ logger.warning(f"Conversation {conversation_id} already exists")
103
+ return False
104
+
105
+ try:
106
+ # Load personas
107
+ surveyors = self.persona_system.list_personas("surveyor")
108
+ patients = self.persona_system.list_personas("patient")
109
+
110
+ surveyor_persona = next((p for p in surveyors if p.get("id") == surveyor_persona_id), None)
111
+ patient_persona = next((p for p in patients if p.get("id") == patient_persona_id), None)
112
+
113
+ if not surveyor_persona or not patient_persona:
114
+ await self._send_error(conversation_id, "Invalid persona IDs")
115
+ return False
116
+
117
+ # Create conversation info
118
+ conv_info = ConversationInfo(
119
+ conversation_id=conversation_id,
120
+ surveyor_persona_id=surveyor_persona_id,
121
+ patient_persona_id=patient_persona_id,
122
+ status=ConversationStatus.STARTING,
123
+ created_at=datetime.now()
124
+ )
125
+
126
+ self.active_conversations[conversation_id] = conv_info
127
+
128
+ # Send status update
129
+ await self._send_status_update(conversation_id, ConversationStatus.STARTING)
130
+
131
+ # Create and start conversation manager
132
+ manager = ConversationManager(
133
+ surveyor_persona=surveyor_persona,
134
+ patient_persona=patient_persona,
135
+ host=host,
136
+ model=model
137
+ )
138
+
139
+ # Start conversation streaming task
140
+ conv_info.task = asyncio.create_task(
141
+ self._stream_conversation(conversation_id, manager)
142
+ )
143
+
144
+ conv_info.status = ConversationStatus.RUNNING
145
+ await self._send_status_update(conversation_id, ConversationStatus.RUNNING)
146
+
147
+ logger.info(f"Started conversation {conversation_id}")
148
+ return True
149
+
150
+ except Exception as e:
151
+ logger.error(f"Failed to start conversation {conversation_id}: {e}")
152
+ await self._send_error(conversation_id, f"Failed to start conversation: {str(e)}")
153
+
154
+ # Clean up
155
+ if conversation_id in self.active_conversations:
156
+ del self.active_conversations[conversation_id]
157
+
158
+ return False
159
+
160
+ async def stop_conversation(self, conversation_id: str) -> bool:
161
+ """Stop an active conversation.
162
+
163
+ Args:
164
+ conversation_id: ID of conversation to stop
165
+
166
+ Returns:
167
+ True if conversation stopped successfully
168
+ """
169
+ if conversation_id not in self.active_conversations:
170
+ logger.warning(f"Conversation {conversation_id} not found")
171
+ return False
172
+
173
+ conv_info = self.active_conversations[conversation_id]
174
+
175
+ try:
176
+ conv_info.status = ConversationStatus.STOPPING
177
+ await self._send_status_update(conversation_id, ConversationStatus.STOPPING)
178
+
179
+ # Cancel the conversation task
180
+ if conv_info.task and not conv_info.task.done():
181
+ conv_info.task.cancel()
182
+ try:
183
+ await conv_info.task
184
+ except asyncio.CancelledError:
185
+ pass
186
+
187
+ # Update status and clean up
188
+ conv_info.status = ConversationStatus.COMPLETED
189
+ await self._send_status_update(conversation_id, ConversationStatus.COMPLETED)
190
+
191
+ del self.active_conversations[conversation_id]
192
+ logger.info(f"Stopped conversation {conversation_id}")
193
+ return True
194
+
195
+ except Exception as e:
196
+ logger.error(f"Error stopping conversation {conversation_id}: {e}")
197
+ conv_info.status = ConversationStatus.ERROR
198
+ await self._send_error(conversation_id, f"Error stopping conversation: {str(e)}")
199
+ return False
200
+
201
+ async def get_conversation_status(self, conversation_id: str) -> Optional[Dict]:
202
+ """Get status of a conversation.
203
+
204
+ Args:
205
+ conversation_id: ID of conversation
206
+
207
+ Returns:
208
+ Dict with conversation status or None if not found
209
+ """
210
+ if conversation_id not in self.active_conversations:
211
+ return None
212
+
213
+ conv_info = self.active_conversations[conversation_id]
214
+ return {
215
+ "conversation_id": conversation_id,
216
+ "status": conv_info.status.value,
217
+ "surveyor_persona_id": conv_info.surveyor_persona_id,
218
+ "patient_persona_id": conv_info.patient_persona_id,
219
+ "created_at": conv_info.created_at.isoformat(),
220
+ "message_count": conv_info.message_count
221
+ }
222
+
223
+ async def list_active_conversations(self) -> Dict[str, Dict]:
224
+ """List all active conversations.
225
+
226
+ Returns:
227
+ Dict mapping conversation_id to status info
228
+ """
229
+ result = {}
230
+ for conv_id, conv_info in self.active_conversations.items():
231
+ result[conv_id] = {
232
+ "status": conv_info.status.value,
233
+ "surveyor_persona_id": conv_info.surveyor_persona_id,
234
+ "patient_persona_id": conv_info.patient_persona_id,
235
+ "created_at": conv_info.created_at.isoformat(),
236
+ "message_count": conv_info.message_count
237
+ }
238
+ return result
239
+
240
+ async def _stream_conversation(self, conversation_id: str, manager: ConversationManager):
241
+ """Stream conversation messages to WebSocket clients.
242
+
243
+ Args:
244
+ conversation_id: ID of the conversation
245
+ manager: ConversationManager instance to stream from
246
+ """
247
+ conv_info = self.active_conversations.get(conversation_id)
248
+ if not conv_info:
249
+ return
250
+
251
+ try:
252
+ async for message in manager.conduct_conversation():
253
+ # Check if conversation was stopped
254
+ if conversation_id not in self.active_conversations:
255
+ break
256
+
257
+ # Update message count
258
+ conv_info.message_count += 1
259
+
260
+ # Add conversation metadata
261
+ websocket_message = {
262
+ "type": "conversation_message",
263
+ "conversation_id": conversation_id,
264
+ **message
265
+ }
266
+
267
+ # Send to all connected clients
268
+ await self.websocket_manager.send_to_conversation(
269
+ conversation_id, websocket_message
270
+ )
271
+
272
+ logger.debug(f"Streamed message {conv_info.message_count} for conversation {conversation_id}")
273
+
274
+ except asyncio.CancelledError:
275
+ logger.info(f"Conversation {conversation_id} streaming cancelled")
276
+ raise
277
+ except Exception as e:
278
+ logger.error(f"Error streaming conversation {conversation_id}: {e}")
279
+ conv_info.status = ConversationStatus.ERROR
280
+ await self._send_error(conversation_id, f"Streaming error: {str(e)}")
281
+ finally:
282
+ # Clean up conversation manager
283
+ try:
284
+ await manager.close()
285
+ except:
286
+ pass
287
+
288
+ # Mark as completed if not already in error state
289
+ if conv_info.status != ConversationStatus.ERROR:
290
+ conv_info.status = ConversationStatus.COMPLETED
291
+ await self._send_status_update(conversation_id, ConversationStatus.COMPLETED)
292
+
293
+ async def _send_status_update(self, conversation_id: str, status: ConversationStatus):
294
+ """Send conversation status update to clients.
295
+
296
+ Args:
297
+ conversation_id: ID of the conversation
298
+ status: New conversation status
299
+ """
300
+ message = {
301
+ "type": "conversation_status",
302
+ "conversation_id": conversation_id,
303
+ "status": status.value,
304
+ "timestamp": datetime.now().isoformat()
305
+ }
306
+
307
+ await self.websocket_manager.send_to_conversation(conversation_id, message)
308
+
309
+ async def _send_error(self, conversation_id: str, error_message: str):
310
+ """Send error message to clients.
311
+
312
+ Args:
313
+ conversation_id: ID of the conversation
314
+ error_message: Error description
315
+ """
316
+ message = {
317
+ "type": "conversation_error",
318
+ "conversation_id": conversation_id,
319
+ "error": error_message,
320
+ "timestamp": datetime.now().isoformat()
321
+ }
322
+
323
+ await self.websocket_manager.send_to_conversation(conversation_id, message)
324
+
325
+ async def cleanup(self):
326
+ """Clean up all active conversations."""
327
+ for conversation_id in list(self.active_conversations.keys()):
328
+ await self.stop_conversation(conversation_id)
329
+
330
+
331
+ # Global service instance (initialized in main.py)
332
+ conversation_service: Optional[ConversationService] = None
333
+
334
+
335
+ def get_conversation_service() -> ConversationService:
336
+ """Get the global conversation service instance.
337
+
338
+ Returns:
339
+ ConversationService instance
340
+
341
+ Raises:
342
+ RuntimeError: If service not initialized
343
+ """
344
+ if conversation_service is None:
345
+ raise RuntimeError("ConversationService not initialized")
346
+ return conversation_service
347
+
348
+
349
+ def initialize_conversation_service(websocket_manager: ConnectionManager):
350
+ """Initialize the global conversation service.
351
+
352
+ Args:
353
+ websocket_manager: WebSocket connection manager
354
+ """
355
+ global conversation_service
356
+ conversation_service = ConversationService(websocket_manager)
357
+ logger.info("ConversationService initialized")
backend/api/websockets/conversation_ws.py CHANGED
@@ -174,9 +174,10 @@ def validate_message(data: dict) -> bool:
174
  # Validate message types
175
  valid_types = [
176
  "conversation_message",
177
- "typing_indicator",
178
  "conversation_control",
179
- "heartbeat"
 
180
  ]
181
 
182
  if data["type"] not in valid_types:
@@ -206,7 +207,11 @@ async def handle_message(data: dict, conversation_id: str):
206
  elif message_type == "conversation_control":
207
  # Handle conversation control (start, pause, stop)
208
  await handle_conversation_control(data, conversation_id)
209
-
 
 
 
 
210
  elif message_type == "heartbeat":
211
  # Respond to heartbeat
212
  await manager.send_to_conversation(conversation_id, {
@@ -217,26 +222,114 @@ async def handle_message(data: dict, conversation_id: str):
217
 
218
  async def handle_conversation_control(data: dict, conversation_id: str):
219
  """Handle conversation control messages.
220
-
221
  Args:
222
  data: Control message data
223
  conversation_id: Target conversation ID
224
  """
225
  control_action = data.get("action")
226
-
227
- if control_action in ["start", "pause", "stop", "resume"]:
228
- # Broadcast control action to all participants
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  await manager.send_to_conversation(conversation_id, {
230
- "type": "conversation_control",
231
- "action": control_action,
232
- "conversation_id": conversation_id,
233
- "timestamp": datetime.now().isoformat(),
234
- "message": f"Conversation {control_action}ed"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  })
236
-
237
- logger.info(f"Conversation {conversation_id} {control_action}ed")
238
- else:
239
- logger.warning(f"Unknown control action: {control_action}")
240
 
241
 
242
  # Export the manager for use in other modules
 
174
  # Validate message types
175
  valid_types = [
176
  "conversation_message",
177
+ "typing_indicator",
178
  "conversation_control",
179
+ "heartbeat",
180
+ "start_conversation" # New message type
181
  ]
182
 
183
  if data["type"] not in valid_types:
 
207
  elif message_type == "conversation_control":
208
  # Handle conversation control (start, pause, stop)
209
  await handle_conversation_control(data, conversation_id)
210
+
211
+ elif message_type == "start_conversation":
212
+ # Handle starting a new conversation
213
+ await handle_start_conversation(data, conversation_id)
214
+
215
  elif message_type == "heartbeat":
216
  # Respond to heartbeat
217
  await manager.send_to_conversation(conversation_id, {
 
222
 
223
  async def handle_conversation_control(data: dict, conversation_id: str):
224
  """Handle conversation control messages.
225
+
226
  Args:
227
  data: Control message data
228
  conversation_id: Target conversation ID
229
  """
230
  control_action = data.get("action")
231
+
232
+ try:
233
+ # Import here to avoid circular imports
234
+ from ..services.conversation_service import get_conversation_service
235
+ service = get_conversation_service()
236
+
237
+ if control_action == "stop":
238
+ success = await service.stop_conversation(conversation_id)
239
+ if success:
240
+ await manager.send_to_conversation(conversation_id, {
241
+ "type": "conversation_control",
242
+ "action": "stop",
243
+ "conversation_id": conversation_id,
244
+ "timestamp": datetime.now().isoformat(),
245
+ "message": "Conversation stopped"
246
+ })
247
+ else:
248
+ await manager.send_to_conversation(conversation_id, {
249
+ "type": "error",
250
+ "error": "Failed to stop conversation",
251
+ "timestamp": datetime.now().isoformat()
252
+ })
253
+
254
+ elif control_action in ["pause", "resume"]:
255
+ # For now, just broadcast the action (pause/resume not fully implemented)
256
+ await manager.send_to_conversation(conversation_id, {
257
+ "type": "conversation_control",
258
+ "action": control_action,
259
+ "conversation_id": conversation_id,
260
+ "timestamp": datetime.now().isoformat(),
261
+ "message": f"Conversation {control_action}d"
262
+ })
263
+ logger.info(f"Conversation {conversation_id} {control_action}d")
264
+
265
+ else:
266
+ logger.warning(f"Unknown control action: {control_action}")
267
+ await manager.send_to_conversation(conversation_id, {
268
+ "type": "error",
269
+ "error": f"Unknown control action: {control_action}",
270
+ "timestamp": datetime.now().isoformat()
271
+ })
272
+
273
+ except Exception as e:
274
+ logger.error(f"Error handling conversation control: {e}")
275
  await manager.send_to_conversation(conversation_id, {
276
+ "type": "error",
277
+ "error": f"Control error: {str(e)}",
278
+ "timestamp": datetime.now().isoformat()
279
+ })
280
+
281
+
282
+ async def handle_start_conversation(data: dict, conversation_id: str):
283
+ """Handle starting a new conversation via WebSocket.
284
+
285
+ Args:
286
+ data: Start conversation message data
287
+ conversation_id: Target conversation ID
288
+ """
289
+ try:
290
+ # Import here to avoid circular imports
291
+ from ..services.conversation_service import get_conversation_service
292
+ service = get_conversation_service()
293
+
294
+ # Extract required fields
295
+ surveyor_persona_id = data.get("surveyor_persona_id")
296
+ patient_persona_id = data.get("patient_persona_id")
297
+ host = data.get("host", "http://localhost:11434")
298
+ model = data.get("model", "llama2:7b")
299
+
300
+ if not surveyor_persona_id or not patient_persona_id:
301
+ await manager.send_to_conversation(conversation_id, {
302
+ "type": "error",
303
+ "error": "Missing required persona IDs",
304
+ "timestamp": datetime.now().isoformat()
305
+ })
306
+ return
307
+
308
+ # Start the conversation
309
+ success = await service.start_conversation(
310
+ conversation_id=conversation_id,
311
+ surveyor_persona_id=surveyor_persona_id,
312
+ patient_persona_id=patient_persona_id,
313
+ host=host,
314
+ model=model
315
+ )
316
+
317
+ if success:
318
+ logger.info(f"Started conversation {conversation_id} via WebSocket")
319
+ else:
320
+ await manager.send_to_conversation(conversation_id, {
321
+ "type": "error",
322
+ "error": "Failed to start conversation",
323
+ "timestamp": datetime.now().isoformat()
324
+ })
325
+
326
+ except Exception as e:
327
+ logger.error(f"Error starting conversation via WebSocket: {e}")
328
+ await manager.send_to_conversation(conversation_id, {
329
+ "type": "error",
330
+ "error": f"Start error: {str(e)}",
331
+ "timestamp": datetime.now().isoformat()
332
  })
 
 
 
 
333
 
334
 
335
  # Export the manager for use in other modules
backend/core/persona_system.py CHANGED
@@ -10,7 +10,7 @@ both survey interviewers and patient respondents. It provides:
10
  Classes:
11
  PersonaSystem: Main persona management system
12
  PersonaPromptBuilder: Constructs prompts with persona context
13
-
14
  Example:
15
  system = PersonaSystem()
16
  prompt = system.build_conversation_prompt(
@@ -20,8 +20,10 @@ Example:
20
  )
21
  """
22
 
 
 
23
  import yaml
24
- from typing import Dict, List, Optional, Any
25
  from pathlib import Path
26
  from datetime import datetime
27
  import logging
@@ -218,7 +220,7 @@ class PersonaSystem:
218
  conversation_history: List[Dict[str, Any]] = None,
219
  current_context: Dict[str, Any] = None,
220
  user_prompt: str = "",
221
- max_history: int = 10) -> tuple[str, str]:
222
  """Build complete prompt for conversation generation.
223
 
224
  Args:
 
10
  Classes:
11
  PersonaSystem: Main persona management system
12
  PersonaPromptBuilder: Constructs prompts with persona context
13
+
14
  Example:
15
  system = PersonaSystem()
16
  prompt = system.build_conversation_prompt(
 
20
  )
21
  """
22
 
23
+ from __future__ import annotations
24
+
25
  import yaml
26
+ from typing import Dict, List, Optional, Any, Tuple
27
  from pathlib import Path
28
  from datetime import datetime
29
  import logging
 
220
  conversation_history: List[Dict[str, Any]] = None,
221
  current_context: Dict[str, Any] = None,
222
  user_prompt: str = "",
223
+ max_history: int = 10) -> Tuple[str, str]:
224
  """Build complete prompt for conversation generation.
225
 
226
  Args: