Spaces:
Runtime error
Runtime error
refactor conversation handling to remove user_id dependency and update schemas for name and email
b8bc03f
import os
from typing import List, Optional

from fastapi import (
    APIRouter,
    HTTPException,
    WebSocket,
    WebSocketDisconnect,
)

from src.config import logger
from src.services import ConversationService
from src.schemas import (
    CreateConversationSchema,
    CreateWebrtcConnectionSchema,
    CreateConversationSummarySchema,
    CreateConversationResponse,
    CreateConversationSummaryResponse,
    CreateWebrtcConnectionResponse,
)
from src.utils import JWTUtil, RedisClient
class ConnectionManager:
    """Track accepted WebSockets and mirror each conversation's status in Redis.

    While a socket for a conversation is connected, the Redis key
    ``session:<conversation_id>`` holds ``"active"``; the key is deleted when
    the socket is disconnected.
    """

    def __init__(self):
        self.jwt = JWTUtil()
        # Sockets accepted by ``connect`` and not yet removed by ``disconnect``.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket, redis_client):
        """Accept ``websocket`` and mark its conversation as active in Redis.

        Args:
            websocket: Incoming connection; its ``conversation_id`` query
                parameter identifies the conversation.
            redis_client: Async Redis client used to record the session status.
        """
        await websocket.accept()
        conversation_id = websocket.query_params.get("conversation_id")
        self.active_connections.append(websocket)
        if not conversation_id:
            # Without an id every such socket would share the key "session:None",
            # so one disconnect would clear another socket's status. Skip tracking.
            logger.warning(
                "WebSocket connected without conversation_id; session status not recorded"
            )
            return
        await self._update_redis_status(
            unique_id=conversation_id,
            redis_client=redis_client,
            status="active",
        )

    async def disconnect(
        self, websocket: WebSocket, redis_client, conversation_id: Optional[str]
    ):
        """Forget ``websocket`` and delete its conversation's Redis session key.

        The list removal is skipped if the socket is not tracked (e.g. it was
        already removed). When ``conversation_id`` is missing/empty nothing is
        deleted, matching ``connect``, which wrote nothing for such sockets.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        if not conversation_id:
            return
        await self._update_redis_status(
            unique_id=str(conversation_id),
            redis_client=redis_client,
            status=None,
        )

    @staticmethod
    def _session_expiry() -> Optional[int]:
        """Return the session TTL in seconds from ``REDIS_SESSION_EXPIRY``.

        Returns ``None`` (key never expires) when the variable is unset or blank.

        Raises:
            ValueError: If the variable is set but is not an integer.
        """
        raw = os.getenv("REDIS_SESSION_EXPIRY")
        if raw is None or not raw.strip():
            return None
        try:
            return int(raw)
        except ValueError:
            raise ValueError(
                f"REDIS_SESSION_EXPIRY must be an integer number of seconds, got {raw!r}"
            ) from None

    async def _update_redis_status(
        self, unique_id: str, redis_client, status: Optional[str]
    ):
        """Set ``session:<unique_id>`` to ``status``; delete the key if ``status`` is falsy."""
        key = f"session:{unique_id}"
        if status:
            await redis_client.set(key, status, ex=self._session_expiry())
        else:
            await redis_client.delete(key)
        # Read the key back so the log shows what Redis actually holds now.
        redis_status = await redis_client.get(key)
        logger.info(f"Redis status for conversation {unique_id}: {redis_status}")
class ConversationController:
    """HTTP and WebSocket endpoints for conversations.

    REST routes are registered on ``api_router`` and the live conversation
    socket on ``websocket_router``. Each request enters its own
    ``ConversationService`` via ``async with``.
    """

    def __init__(self):
        self.websocket_connection_manager = ConnectionManager()
        self.redis_client = RedisClient().client
        # The service class itself; a fresh instance is created per request.
        self.service = ConversationService
        self.api_router = APIRouter()
        self.websocket_router = APIRouter()
        self.api_router.add_api_route(
            "/conversations",
            self.create_conversation,
            methods=["POST"],
            response_model=CreateConversationResponse,
        )
        # NOTE(review): the {conversation_id} path segment in the next two routes is
        # not bound to any handler parameter; the handlers read conversation_id
        # from the request body, so the URL value is never used or cross-checked.
        self.api_router.add_api_route(
            "/conversations/{conversation_id}",
            self.create_webrtc_connection,
            methods=["POST"],
            response_model=CreateWebrtcConnectionResponse,
        )
        self.api_router.add_api_route(
            "/conversations/{conversation_id}/summary",
            self.create_conversation_summary,
            methods=["POST"],
            response_model=CreateConversationSummaryResponse,
        )
        self.websocket_router.add_websocket_route("/conversations", self.conversation)

    async def create_conversation(self, data: CreateConversationSchema):
        """Create a conversation from the submitted name, email and modality.

        Raises:
            HTTPException: Re-raised unchanged if the service raises one;
                otherwise status 500 for any other failure.
        """
        try:
            async with self.service() as service:
                return await service.create_conversation(
                    name=data.name, email=data.email, modality=data.modality
                )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error creating conversation: {str(e)}")
            raise HTTPException(
                status_code=500, detail="Failed to create conversation"
            ) from e

    async def create_webrtc_connection(
        self,
        data: CreateWebrtcConnectionSchema,
    ):
        """Create a WebRTC connection for ``data.conversation_id``.

        The client's offer (``data.offer``) is passed to the service as a dict
        via ``model_dump()``.

        Raises:
            HTTPException: Re-raised unchanged if the service raises one;
                otherwise status 500 for any other failure.
        """
        try:
            async with self.service() as service:
                return await service.create_webrtc_connection(
                    conversation_id=data.conversation_id,
                    offer=data.offer.model_dump(),
                )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error creating WebRTC connection: {str(e)}")
            raise HTTPException(
                status_code=500, detail="Failed to create WebRTC connection"
            ) from e

    async def conversation(self, websocket: WebSocket):
        """Run a live conversation over ``websocket``.

        The ``conversation_id`` and ``modality`` query parameters are forwarded
        to the service. The socket is always released from the connection
        manager (and its Redis session key cleared) when the handler exits,
        whether the conversation ends normally, the client disconnects, an
        error is raised, or the task is cancelled.
        """
        await self.websocket_connection_manager.connect(
            websocket=websocket, redis_client=self.redis_client
        )
        conversation_id = websocket.query_params.get("conversation_id")
        modality = websocket.query_params.get("modality")
        try:
            async with self.service() as service:
                await service.conversation(
                    websocket=websocket,
                    conversation_id=conversation_id,
                    modality=modality,
                    redis_client=self.redis_client,
                )
        except WebSocketDisconnect:
            logger.info("WebSocket connection closed")
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error in WebSocket conversation: {str(e)}")
        finally:
            # Cleanup lives in ``finally`` so that a normal return, a re-raised
            # HTTPException, or cancellation also drop the socket from
            # active_connections and clear its "active" Redis session key.
            await self._release_websocket(websocket, conversation_id)

    async def _release_websocket(self, websocket: WebSocket, conversation_id):
        """Best-effort disconnect that logs, rather than raises, cleanup failures.

        Prevents an error during cleanup (e.g. from Redis) from replacing the
        exception, if any, that ended the conversation.
        """
        try:
            await self.websocket_connection_manager.disconnect(
                websocket=websocket,
                redis_client=self.redis_client,
                conversation_id=conversation_id,
            )
        except Exception as e:
            logger.error(f"Error releasing WebSocket connection: {str(e)}")

    async def create_conversation_summary(self, data: CreateConversationSummarySchema):
        """Ask the service to create a summary for ``data.conversation_id``.

        Raises:
            HTTPException: Re-raised unchanged if the service raises one;
                otherwise status 500 for any other failure.
        """
        try:
            async with self.service() as service:
                return await service.create_conversation_summary(
                    conversation_id=data.conversation_id
                )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error creating conversation summary: {str(e)}")
            raise HTTPException(
                status_code=500, detail="Failed to create conversation summary"
            ) from e