from fastapi import WebSocket
from src.config import logger
from src.utils import OpenAIClient
from src.schemas import UserSignInSchema
from src.models import Conversation, User, Message
from src.repositories import ConversationRepository, MessageRepository, UserRepository
from ._auth_service import AuthService
from ._websocket_service import WebSocketService
from ._web_rtc_service import WebRTCService


class ConversationService:
    """Coordinate conversation creation, history, live sessions and summaries.

    Usable as an async context manager: entering returns the service itself
    and exiting performs no cleanup.
    """

    def __init__(self):
        # The OpenAI client and the two transport services are stored as
        # classes (not instances); each use instantiates them via ``async with``.
        self.openai_client = OpenAIClient
        self.websocket_service = WebSocketService
        self.web_rtc_service = WebRTCService
        self.auth_service = AuthService()
        self.conversation_repository = ConversationRepository()
        self.message_repository = MessageRepository()
        self.user_repository = UserRepository()

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        pass

    async def validate_conversation(self, user_id, conversation_id) -> bool:
        """Return True if the conversation exists and belongs to the given user.

        Returns False (rather than raising) when either the conversation or
        the user cannot be found.
        """
        conversation_object: Conversation = (
            await self.conversation_repository.get_by_id(conversation_id)
        )
        if not conversation_object:
            return False
        user: User = await self.user_repository.get_by_id(user_id)
        if not user:
            return False
        return conversation_object.user.id == user.id

    async def get_conversation_history(self, conversation_id, find_all=False, limit=20):
        """Build a newline-joined ``"role: content"`` transcript of a conversation.

        Args:
            conversation_id: Identifier passed to ``ConversationRepository.get_by_id``.
            find_all: When True, include every message; otherwise keep only
                the last ``limit`` messages.
            limit: Number of trailing messages kept when ``find_all`` is False.
                A value <= 0 keeps no messages.

        Returns:
            ``{"conversation_history": str}``. The string is empty when the
            conversation does not exist or no messages are selected.
        """
        conversation_object: Conversation = (
            await self.conversation_repository.get_by_id(conversation_id)
        )
        if not conversation_object:
            return {"conversation_history": ""}
        messages: list[Message] = (
            await self.message_repository.get_messages_by_conversation(
                conversation_object
            )
        )
        if not messages:
            return {"conversation_history": ""}
        if not find_all:
            # ``messages[-0:]`` would return the whole list (and a negative
            # limit would drop leading messages), so guard non-positive limits.
            messages = messages[-limit:] if limit > 0 else []
        conversation_history = "\n".join(
            f"{message.role}: {message.content}" for message in messages
        )
        return {"conversation_history": conversation_history}

    async def create_conversation(self, name, email, modality):
        """Sign the user in and create an empty-summary conversation for them.

        Returns:
            ``{"conversation_id": str}`` for the newly inserted conversation.
        """
        user_id_object = await self.auth_service.sign_in(
            UserSignInSchema(name=name, email=email)
        )
        user_object = await self.user_repository.get_by_id(user_id_object["user_id"])
        conversation_object = await self.conversation_repository.insert_one(
            Conversation(user=user_object, modality=modality, summary="")
        )
        return {"conversation_id": f"{conversation_object.id}"}

    async def create_webrtc_connection(self, conversation_id, offer):
        """Open an OpenAI WebRTC connection for a voice conversation.

        A session is created to obtain an ephemeral client secret, which is
        then used together with the caller's ``offer`` to create the connection.

        Raises:
            ValueError: If the conversation does not exist.
            Exception: If the conversation's modality is not ``"voice"``.
        """
        conversation_object: Conversation = (
            await self.conversation_repository.get_by_id(conversation_id)
        )
        if not conversation_object:
            raise ValueError("Conversation not found")
        if conversation_object.modality != "voice":
            raise Exception("Invalid modality")
        async with self.openai_client() as client:
            session_data = await client.create_openai_session(text_mode_only=False)
            ephemeral_key = session_data["client_secret"]["value"]
            response = await client.create_webrtc_connection(
                ephemeral_key=ephemeral_key, offer=offer
            )
        return response

    async def conversation(
        self, websocket: WebSocket, conversation_id, modality, redis_client
    ):
        """Run a live conversation over ``websocket`` for the given modality.

        ``"voice"`` is delegated to ``WebRTCService`` and ``"text"`` to
        ``WebSocketService``. Both receive the last 20 messages as context.

        Raises:
            Exception: If ``modality`` is unsupported (the socket is closed
                with code 1008 first).
            ValueError: If the conversation does not exist (the socket is
                closed with code 1008 first).
        """
        # Validate cheap input before touching the database.
        if modality not in ("text", "voice"):
            await websocket.close(code=1008, reason="Unsupported modality")
            raise Exception("Unsupported modality")
        conversation_object: Conversation = (
            await self.conversation_repository.get_by_id(conversation_id)
        )
        if not conversation_object:
            await websocket.close(code=1008, reason="Conversation not found")
            raise ValueError("Conversation not found")
        user_id = conversation_object.user.id
        conversation_history = await self.get_conversation_history(
            conversation_id, find_all=False, limit=20
        )
        if modality == "voice":
            async with self.web_rtc_service() as rtc_service:
                await rtc_service.handle_voice_conversion(
                    websocket=websocket,
                    conversation_id=conversation_id,
                    conversation_history=conversation_history,
                    user_id=user_id,
                    redis_client=redis_client,
                )
        else:
            async with self.websocket_service() as ws_service:
                await ws_service.handle_text_conversation_with_openai_websockets(
                    client_websocket=websocket,
                    conversation_id=conversation_id,
                    conversation_history=conversation_history,
                    user_id=user_id,
                    redis_client=redis_client,
                )

    async def create_conversation_summary(self, user_id, conversation_id):
        """Generate, persist and return a summary of the full conversation.

        An empty conversation gets an empty summary without calling OpenAI.

        NOTE(review): ``user_id`` is accepted for interface compatibility but
        is not used. Ownership is not checked here; callers should use
        ``validate_conversation`` if that is required.

        Returns:
            ``{"conversation_summary": str}`` read back from the updated record.

        Raises:
            ValueError: If the conversation does not exist.
        """
        conversation_object: Conversation = (
            await self.conversation_repository.get_by_id(conversation_id)
        )
        if not conversation_object:
            raise ValueError("Conversation not found")
        history_text = (
            await self.get_conversation_history(conversation_id, find_all=True)
        )["conversation_history"]
        if history_text:
            async with self.openai_client() as client:
                # Interpolate the transcript string itself, not the wrapping
                # dict (whose repr would escape newlines and add dict syntax).
                conversation_summary = await client.text_generation(
                    query=f"Generate Conversation Summary\n\n {history_text}"
                )
        else:
            conversation_summary = ""
        conversation_object.summary = conversation_summary
        updated_conversation: Conversation = await self.conversation_repository.update(
            conversation_object.id, conversation_object
        )
        return {"conversation_summary": updated_conversation.summary}