keepme-backend / src /controllers /_conversation_controller.py
ramanjitsingh1368's picture
refactor conversation handling to remove user_id dependency and update schemas for name and email
b8bc03f
import os
from typing import List, Optional

from fastapi import (
    APIRouter,
    HTTPException,
    WebSocket,
    WebSocketDisconnect,
)

from src.config import logger
from src.services import ConversationService
from src.schemas import (
    CreateConversationSchema,
    CreateWebrtcConnectionSchema,
    CreateConversationSummarySchema,
    CreateConversationResponse,
    CreateConversationSummaryResponse,
    CreateWebrtcConnectionResponse,
)
from src.utils import JWTUtil, RedisClient
class ConnectionManager:
    """Track accepted conversation WebSockets and mirror their status in Redis.

    While a conversation's socket is connected, Redis holds the key
    ``session:<conversation_id>`` with the value ``"active"``. The TTL comes from
    the ``REDIS_SESSION_EXPIRY`` environment variable, in seconds. The key is
    deleted on disconnect.
    """

    def __init__(self):
        # NOTE(review): not referenced anywhere in this class; confirm whether
        # it is still needed after the user_id removal.
        self.jwt = JWTUtil()
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket, redis_client):
        """Accept *websocket*, mark its conversation active in Redis, and track it.

        The conversation id is taken from the ``conversation_id`` query parameter.

        Args:
            websocket: The incoming WebSocket to accept.
            redis_client: Async Redis client used to store the session marker.
        """
        await websocket.accept()
        conversation_id = websocket.query_params.get("conversation_id")
        # NOTE(review): a missing query parameter produces the key
        # "session:None"; confirm whether such connections should be rejected.
        await self._update_redis_status(
            unique_id=f"{conversation_id}",
            redis_client=redis_client,
            status="active",
        )
        # Register only after the Redis write succeeds. Otherwise a failed write
        # would leave a stale socket in ``active_connections`` that nothing removes.
        self.active_connections.append(websocket)

    async def disconnect(
        self, websocket: WebSocket, redis_client, conversation_id: str
    ):
        """Stop tracking *websocket* and delete its conversation's session key.

        A socket that is not currently tracked is ignored, so calling this more
        than once for the same socket does not raise on the list removal.

        Args:
            websocket: The WebSocket being closed.
            redis_client: Async Redis client holding the session marker.
            conversation_id: Conversation whose ``session:<id>`` key is deleted.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        await self._update_redis_status(
            unique_id=f"{conversation_id}",
            redis_client=redis_client,
            status=None,
        )

    @staticmethod
    def _session_expiry_seconds() -> Optional[int]:
        """Return the session TTL from ``REDIS_SESSION_EXPIRY``, or None if unset.

        The value is parsed to ``int`` here rather than passed to Redis as a raw
        string, because not every client version accepts string expiries.

        Raises:
            ValueError: If the variable is set but is not an integer.
        """
        raw_expiry = os.getenv("REDIS_SESSION_EXPIRY")
        if raw_expiry is None:
            return None
        try:
            return int(raw_expiry)
        except ValueError as e:
            raise ValueError(
                "REDIS_SESSION_EXPIRY must be an integer number of seconds, "
                f"got {raw_expiry!r}"
            ) from e

    async def _update_redis_status(
        self, unique_id: str, redis_client, status: Optional[str]
    ):
        """Set ``session:<unique_id>`` to *status*, or delete it when *status* is falsy.

        Args:
            unique_id: Conversation id used to build the Redis key.
            redis_client: Async Redis client.
            status: Value to store, or None/empty to delete the key.
        """
        key = f"session:{unique_id}"
        if status:
            await redis_client.set(key, status, ex=self._session_expiry_seconds())
        else:
            await redis_client.delete(key)
        # Read the key back so the log line shows what Redis actually holds.
        redis_status = await redis_client.get(key)
        logger.info(f"Redis status for conversation {unique_id}: {redis_status}")
class ConversationController:
    """HTTP and WebSocket entry points for conversations.

    Routes are registered on two routers:

    * ``api_router``: ``POST /conversations``,
      ``POST /conversations/{conversation_id}`` and
      ``POST /conversations/{conversation_id}/summary``.
    * ``websocket_router``: the ``/conversations`` WebSocket.

    NOTE(review): the routers are presumably mounted on the app elsewhere.

    Each handler opens a fresh ``ConversationService`` with ``async with`` and
    delegates the work to it.
    """

    def __init__(self):
        self.websocket_connection_manager = ConnectionManager()
        self.redis_client = RedisClient().client
        # The service class itself; handlers instantiate it per request.
        self.service = ConversationService
        self.api_router = APIRouter()
        self.websocket_router = APIRouter()

        self.api_router.add_api_route(
            "/conversations",
            self.create_conversation,
            methods=["POST"],
            response_model=CreateConversationResponse,
        )
        self.api_router.add_api_route(
            "/conversations/{conversation_id}",
            self.create_webrtc_connection,
            methods=["POST"],
            response_model=CreateWebrtcConnectionResponse,
        )
        self.api_router.add_api_route(
            "/conversations/{conversation_id}/summary",
            self.create_conversation_summary,
            methods=["POST"],
            response_model=CreateConversationSummaryResponse,
        )
        self.websocket_router.add_websocket_route("/conversations", self.conversation)

    async def create_conversation(self, data: CreateConversationSchema):
        """Create a conversation from the body's ``name``, ``email`` and ``modality``.

        Returns:
            The service result, serialized through ``CreateConversationResponse``.

        Raises:
            HTTPException: Re-raised unchanged from the service, or status 500
                for any other error.
        """
        try:
            async with self.service() as service:
                return await service.create_conversation(
                    name=data.name, email=data.email, modality=data.modality
                )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error creating conversation: {str(e)}")
            raise HTTPException(
                status_code=500, detail="Failed to create conversation"
            ) from e

    async def create_webrtc_connection(
        self,
        data: CreateWebrtcConnectionSchema,
    ):
        """Create a WebRTC connection for ``data.conversation_id`` from ``data.offer``.

        NOTE(review): the route path contains ``{conversation_id}``, but this
        handler ignores it and uses the id from the request body. Confirm the
        two cannot disagree.

        Returns:
            The service result, serialized through ``CreateWebrtcConnectionResponse``.

        Raises:
            HTTPException: Re-raised unchanged from the service, or status 500
                for any other error.
        """
        try:
            async with self.service() as service:
                return await service.create_webrtc_connection(
                    conversation_id=data.conversation_id,
                    offer=data.offer.model_dump(),
                )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error creating WebRTC connection: {str(e)}")
            raise HTTPException(
                status_code=500, detail="Failed to create WebRTC connection"
            ) from e

    async def conversation(self, websocket: WebSocket):
        """Serve one conversation over *websocket* until it ends.

        Query parameters:
            conversation_id: Conversation to attach to. Also used for the Redis
                session key.
            modality: Forwarded unchanged to the service.

        Before the service runs, the socket is registered and its Redis session
        is marked active. Afterwards it is always unregistered, whether the
        service returns normally, the client disconnects, or an error occurs.
        Unexpected errors are logged and not re-raised; ``HTTPException``
        propagates.
        """
        await self.websocket_connection_manager.connect(
            websocket=websocket, redis_client=self.redis_client
        )
        conversation_id = websocket.query_params.get("conversation_id")
        modality = websocket.query_params.get("modality")
        try:
            async with self.service() as service:
                await service.conversation(
                    websocket=websocket,
                    conversation_id=conversation_id,
                    modality=modality,
                    redis_client=self.redis_client,
                )
        except WebSocketDisconnect:
            logger.info("WebSocket connection closed")
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error in WebSocket conversation: {str(e)}")
        finally:
            # Cleanup runs on every exit path. Previously a normal return or an
            # HTTPException left the socket tracked and the Redis session
            # marked "active".
            await self.websocket_connection_manager.disconnect(
                websocket=websocket,
                redis_client=self.redis_client,
                conversation_id=conversation_id,
            )

    async def create_conversation_summary(self, data: CreateConversationSummarySchema):
        """Create a summary for ``data.conversation_id``.

        NOTE(review): as with the WebRTC route, the ``{conversation_id}`` path
        parameter is not read; the body field is used instead.

        Returns:
            The service result, serialized through
            ``CreateConversationSummaryResponse``.

        Raises:
            HTTPException: Re-raised unchanged from the service, or status 500
                for any other error.
        """
        try:
            async with self.service() as service:
                return await service.create_conversation_summary(
                    conversation_id=data.conversation_id
                )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error creating conversation summary: {str(e)}")
            raise HTTPException(
                status_code=500, detail="Failed to create conversation summary"
            ) from e