import os
from typing import List, Optional

from fastapi import (
    APIRouter,
    HTTPException,
    WebSocket,
    WebSocketDisconnect,
)

from src.config import logger
from src.services import ConversationService
from src.schemas import (
    CreateConversationSchema,
    CreateWebrtcConnectionSchema,
    CreateConversationSummarySchema,
    CreateConversationResponse,
    CreateConversationSummaryResponse,
    CreateWebrtcConnectionResponse,
)
from src.utils import JWTUtil, RedisClient


def _session_expiry_seconds() -> Optional[int]:
    """Return the Redis session TTL from ``REDIS_SESSION_EXPIRY``, in seconds.

    The variable is read on every call so a changed environment is picked up
    without re-importing the module.

    Returns:
        The TTL as a positive integer, or ``None`` when the variable is unset
        or blank (the key is then stored without an expiry).

    Raises:
        ValueError: If the variable is set but is not a positive integer.
    """
    raw = os.getenv("REDIS_SESSION_EXPIRY")
    if raw is None or not raw.strip():
        return None
    try:
        seconds = int(raw)
    except ValueError:
        seconds = 0  # fall through to the shared error below
    if seconds <= 0:
        raise ValueError(
            "REDIS_SESSION_EXPIRY must be a positive integer number of "
            f"seconds, got {raw!r}"
        )
    return seconds


class ConnectionManager:
    """Track accepted WebSocket connections and mirror their status in Redis.

    Each connection is recorded under the Redis key
    ``session:<conversation_id>``, where ``conversation_id`` comes from the
    WebSocket query string.
    """

    def __init__(self):
        self.jwt = JWTUtil()
        # WebSockets that have been accepted and not yet disconnected.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket, redis_client):
        """Accept ``websocket`` and mark its conversation as active in Redis.

        Args:
            websocket: The incoming connection; its ``conversation_id`` query
                parameter names the session key.
            redis_client: Awaitable Redis client exposing ``set``/``get``/
                ``delete``.
        """
        await websocket.accept()
        # NOTE(review): a missing query parameter yields the key
        # "session:None" - confirm whether such connections should be rejected.
        conversation_id = websocket.query_params.get("conversation_id")
        # Write to Redis before registering the socket so that a failed write
        # does not leave a stale entry in ``active_connections``.
        await self._update_redis_status(
            unique_id=f"{conversation_id}",
            redis_client=redis_client,
            status="active",
        )
        self.active_connections.append(websocket)

    async def disconnect(
        self, websocket: WebSocket, redis_client, conversation_id: str
    ):
        """Forget ``websocket`` and delete its session key from Redis.

        Safe to call for a websocket that is not (or no longer) registered.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        await self._update_redis_status(
            unique_id=f"{conversation_id}",
            redis_client=redis_client,
            status=None,
        )

    async def _update_redis_status(
        self, unique_id: str, redis_client, status: Optional[str]
    ):
        """Set or clear ``session:<unique_id>`` and log the resulting value.

        Args:
            unique_id: Suffix of the Redis key.
            redis_client: Awaitable Redis client.
            status: Value to store; a falsy value deletes the key instead.

        Raises:
            ValueError: If ``REDIS_SESSION_EXPIRY`` is set to an invalid value.
        """
        key = f"session:{unique_id}"
        if status:
            # Pass the TTL as an int: the environment only provides strings.
            await redis_client.set(key, status, ex=_session_expiry_seconds())
        else:
            await redis_client.delete(key)
        redis_status = await redis_client.get(key)
        logger.info(f"Redis status for user {unique_id}: {redis_status}")


class ConversationController:
    """Register the conversation REST routes and the conversation WebSocket.

    ``api_router`` carries the POST endpoints and ``websocket_router`` the
    WebSocket endpoint. A new ``ConversationService`` is opened as an async
    context manager for every request.
    """

    def __init__(self):
        self.websocket_connection_manager = ConnectionManager()
        self.redis_client = RedisClient().client
        # Stored as a class; instantiated per request via ``self.service()``.
        self.service = ConversationService
        self.api_router = APIRouter()
        self.websocket_router = APIRouter()

        self.api_router.add_api_route(
            "/conversations",
            self.create_conversation,
            methods=["POST"],
            response_model=CreateConversationResponse,
        )
        # NOTE(review): the handlers for the two routes below take the
        # conversation id from the request body, not from the
        # ``{conversation_id}`` path segment - confirm this is intended.
        self.api_router.add_api_route(
            "/conversations/{conversation_id}",
            self.create_webrtc_connection,
            methods=["POST"],
            response_model=CreateWebrtcConnectionResponse,
        )
        self.api_router.add_api_route(
            "/conversations/{conversation_id}/summary",
            self.create_conversation_summary,
            methods=["POST"],
            response_model=CreateConversationSummaryResponse,
        )
        self.websocket_router.add_websocket_route("/conversations", self.conversation)

    async def create_conversation(self, data: CreateConversationSchema):
        """Create a conversation from ``data.name``, ``data.email`` and
        ``data.modality``.

        Raises:
            HTTPException: Re-raised unchanged from the service; any other
                error is logged and reported as a 500.
        """
        try:
            async with self.service() as service:
                return await service.create_conversation(
                    name=data.name, email=data.email, modality=data.modality
                )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error creating conversation: {str(e)}")
            raise HTTPException(status_code=500, detail="Failed to create conversation")

    async def create_webrtc_connection(
        self,
        data: CreateWebrtcConnectionSchema,
    ):
        """Create a WebRTC connection for ``data.conversation_id`` using the
        serialized ``data.offer``.

        Raises:
            HTTPException: Re-raised unchanged from the service; any other
                error is logged and reported as a 500.
        """
        try:
            async with self.service() as service:
                return await service.create_webrtc_connection(
                    conversation_id=data.conversation_id,
                    offer=data.offer.model_dump(),
                )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error creating WebRTC connection: {str(e)}")
            raise HTTPException(
                status_code=500, detail="Failed to create WebRTC connection"
            )

    async def conversation(self, websocket: WebSocket):
        """Run a conversation over ``websocket`` until it ends.

        The ``conversation_id`` and ``modality`` query parameters are passed
        to the service. The connection is always unregistered and its Redis
        session key removed when the handler exits, whether the service
        returned normally, the client disconnected, or an error occurred.
        """
        conversation_id = websocket.query_params.get("conversation_id")
        modality = websocket.query_params.get("modality")
        await self.websocket_connection_manager.connect(
            websocket=websocket, redis_client=self.redis_client
        )
        try:
            async with self.service() as service:
                await service.conversation(
                    websocket=websocket,
                    conversation_id=conversation_id,
                    modality=modality,
                    redis_client=self.redis_client,
                )
        except WebSocketDisconnect:
            logger.info("WebSocket connection closed")
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error in WebSocket conversation: {str(e)}")
        finally:
            # Cleanup lives in ``finally`` so a normal return or a re-raised
            # HTTPException cannot leave the socket registered or the session
            # key behind.
            await self.websocket_connection_manager.disconnect(
                websocket=websocket,
                redis_client=self.redis_client,
                conversation_id=conversation_id,
            )

    async def create_conversation_summary(self, data: CreateConversationSummarySchema):
        """Create a summary for the conversation named by
        ``data.conversation_id``.

        Raises:
            HTTPException: Re-raised unchanged from the service; any other
                error is logged and reported as a 500.
        """
        try:
            async with self.service() as service:
                return await service.create_conversation_summary(
                    conversation_id=data.conversation_id
                )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error creating conversation summary: {str(e)}")
            raise HTTPException(
                status_code=500, detail="Failed to create conversation summary"
            )