Spaces:
Runtime error
Runtime error
refactor conversation handling to remove user_id dependency and update schemas for name and email
b8bc03f | from fastapi import WebSocket | |
| from src.config import logger | |
| from src.utils import OpenAIClient | |
| from src.schemas import UserSignInSchema | |
| from src.models import Conversation, User, Message | |
| from src.repositories import ConversationRepository, MessageRepository, UserRepository | |
| from ._auth_service import AuthService | |
| from ._websocket_service import WebSocketService | |
| from ._web_rtc_service import WebRTCService | |
class ConversationService:
    """Coordinate conversations: persistence, transcripts, summaries and live sessions.

    Wraps the conversation, message and user repositories, signs users in via
    ``AuthService`` and dispatches live conversations to ``WebRTCService``
    (voice) or ``WebSocketService`` (text). Usable with ``async with``:
    entering returns the service itself and exiting releases nothing.
    """

    # Modalities accepted by ``conversation``.
    _SUPPORTED_MODALITIES = ("text", "voice")

    def __init__(self):
        # The OpenAI client and the two transport services are stored as
        # classes, not instances: each use creates a fresh object via
        # ``async with self.<attr>() as ...``.
        self.openai_client = OpenAIClient
        self.websocket_service = WebSocketService
        self.web_rtc_service = WebRTCService
        self.auth_service = AuthService()
        self.conversation_repository = ConversationRepository()
        self.message_repository = MessageRepository()
        self.user_repository = UserRepository()

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        # Nothing to release; returning None lets exceptions propagate.
        pass

    async def _get_conversation_or_raise(self, conversation_id) -> "Conversation":
        """Fetch a conversation by id.

        Raises:
            ValueError: if the repository returns nothing for ``conversation_id``.
        """
        conversation_object = await self.conversation_repository.get_by_id(
            conversation_id
        )
        if not conversation_object:
            raise ValueError(f"Conversation {conversation_id} not found")
        return conversation_object

    async def validate_conversation(self, user_id, conversation_id) -> bool:
        """Return True if ``conversation_id`` exists and belongs to ``user_id``.

        Returns False when the user or the conversation cannot be found, or
        when the conversation's owner id differs from the user's id.
        """
        user: User = await self.user_repository.get_by_id(user_id)
        if not user:
            # Previously this fell through to ``user.id`` and raised AttributeError.
            return False
        conversation_object: Conversation = (
            await self.conversation_repository.get_by_id(conversation_id)
        )
        if not conversation_object:
            return False
        return conversation_object.user.id == user.id

    async def get_conversation_history(self, conversation_id, find_all=False, limit=20):
        """Return the transcript of a conversation as ``{"conversation_history": str}``.

        Each message is rendered as ``"<role>: <content>"``, one per line, in
        the order the message repository returns them. Unless ``find_all`` is
        true, only the last ``limit`` messages are kept; a non-positive
        ``limit`` keeps none. A conversation without messages yields ``""``.

        Raises:
            ValueError: if the conversation does not exist.
        """
        conversation_object = await self._get_conversation_or_raise(conversation_id)
        messages: list[Message] = (
            await self.message_repository.get_messages_by_conversation(
                conversation_object
            )
        )
        if not find_all:
            # ``messages[-0:]`` would return the whole list, so limit <= 0 is
            # handled explicitly.
            messages = messages[-limit:] if limit > 0 else []
        conversation_history = "\n".join(
            f"{message.role}: {message.content}" for message in messages
        )
        return {"conversation_history": conversation_history}

    async def create_conversation(self, name, email, modality):
        """Sign the user in via ``AuthService.sign_in`` and open a new conversation.

        The conversation is stored with the given ``modality`` and an empty
        summary.

        Returns:
            ``{"conversation_id": str}`` for the newly inserted conversation.
        """
        user_id_object = await self.auth_service.sign_in(
            UserSignInSchema(name=name, email=email)
        )
        user_object = await self.user_repository.get_by_id(user_id_object["user_id"])
        conversation_object = await self.conversation_repository.insert_one(
            Conversation(user=user_object, modality=modality, summary="")
        )
        return {"conversation_id": f"{conversation_object.id}"}

    async def create_webrtc_connection(self, conversation_id, offer):
        """Open an OpenAI session and create a WebRTC connection for ``offer``.

        A session is created with ``text_mode_only=False``. Its ephemeral
        client secret is then used to answer ``offer``.

        Returns:
            Whatever ``OpenAIClient.create_webrtc_connection`` returns.

        Raises:
            ValueError: if the conversation does not exist or its modality is
                not ``"voice"``.
        """
        conversation_object = await self._get_conversation_or_raise(conversation_id)
        if conversation_object.modality != "voice":
            raise ValueError("Invalid modality")
        async with self.openai_client() as client:
            session_data = await client.create_openai_session(text_mode_only=False)
            ephemeral_key = session_data["client_secret"]["value"]
            response = await client.create_webrtc_connection(
                ephemeral_key=ephemeral_key, offer=offer
            )
        return response

    async def conversation(
        self, websocket: WebSocket, conversation_id, modality, redis_client
    ):
        """Run a live conversation over ``websocket`` in the requested ``modality``.

        ``"voice"`` is delegated to the WebRTC service and ``"text"`` to the
        WebSocket service. Either receives the last 20 messages of history, in
        the dict form returned by ``get_conversation_history``.

        Raises:
            ValueError: if ``modality`` is unsupported or the conversation does
                not exist. In both cases the websocket is first closed with
                code 1008.
        """
        # Validate the request before touching the database.
        if modality not in self._SUPPORTED_MODALITIES:
            await websocket.close(code=1008, reason="Unsupported modality")
            raise ValueError("Unsupported modality")

        conversation_object: Conversation = (
            await self.conversation_repository.get_by_id(conversation_id)
        )
        if not conversation_object:
            # Close the client socket instead of failing on ``None.user``.
            await websocket.close(code=1008, reason="Conversation not found")
            raise ValueError(f"Conversation {conversation_id} not found")
        user_id = conversation_object.user.id

        conversation_history = await self.get_conversation_history(
            conversation_id, find_all=False, limit=20
        )
        if modality == "voice":
            async with self.web_rtc_service() as rtc_service:
                # Method name as defined on WebRTCService.
                await rtc_service.handle_voice_conversion(
                    websocket=websocket,
                    conversation_id=conversation_id,
                    conversation_history=conversation_history,
                    user_id=user_id,
                    redis_client=redis_client,
                )
        else:
            async with self.websocket_service() as ws_service:
                await ws_service.handle_text_conversation_with_openai_websockets(
                    client_websocket=websocket,
                    conversation_id=conversation_id,
                    conversation_history=conversation_history,
                    user_id=user_id,
                    redis_client=redis_client,
                )

    async def create_conversation_summary(self, user_id, conversation_id):
        """Generate, store and return a summary of the whole conversation.

        Args:
            user_id: Unused; kept so existing callers keep working.
            conversation_id: Conversation to summarise.

        Returns:
            ``{"conversation_summary": str}`` holding the summary as stored by
            the repository update. An empty transcript is stored as ``""``
            without calling OpenAI.

        Raises:
            ValueError: if the conversation does not exist.
        """
        conversation_object = await self._get_conversation_or_raise(conversation_id)
        history = (
            await self.get_conversation_history(conversation_id, find_all=True)
        )["conversation_history"]

        if history:
            async with self.openai_client() as client:
                # Interpolate the transcript text itself; the previous code
                # embedded the wrapper dict's repr into the prompt.
                summary = await client.text_generation(
                    query=f"Generate Conversation Summary\n\n {history}"
                )
        else:
            summary = ""

        conversation_object.summary = summary
        updated_conversation: Conversation = (
            await self.conversation_repository.update(
                conversation_object.id, conversation_object
            )
        )
        return {"conversation_summary": updated_conversation.summary}