keepme-backend / src /services /_conversation_service.py
ramanjitsingh1368's picture
refactor conversation handling to remove user_id dependency and update schemas for name and email
b8bc03f
from fastapi import WebSocket
from src.config import logger
from src.utils import OpenAIClient
from src.schemas import UserSignInSchema
from src.models import Conversation, User, Message
from src.repositories import ConversationRepository, MessageRepository, UserRepository
from ._auth_service import AuthService
from ._websocket_service import WebSocketService
from ._web_rtc_service import WebRTCService
class ConversationService:
    """Coordinate conversation creation, history, live sessions and summaries.

    Wires together the conversation/message/user repositories, the auth
    service (used to sign users in by name and email), and the OpenAI,
    WebSocket and WebRTC helpers used for live text and voice sessions.
    Usable as an async context manager: entering returns the instance and
    exiting performs no cleanup.
    """

    # Modalities accepted by ``conversation``.
    _SUPPORTED_MODALITIES = ("text", "voice")

    def __init__(self):
        # The client/service classes are stored uninstantiated: each use below
        # opens a fresh instance via ``async with self.<attr>() as ...``.
        self.openai_client = OpenAIClient
        self.websocket_service = WebSocketService
        self.web_rtc_service = WebRTCService
        self.auth_service = AuthService()
        self.conversation_repository = ConversationRepository()
        self.message_repository = MessageRepository()
        self.user_repository = UserRepository()

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        # Nothing to release; returning None lets exceptions propagate.
        pass

    async def _get_conversation_or_raise(self, conversation_id):
        """Return the stored conversation for ``conversation_id``.

        Raises:
            ValueError: if the repository returns no conversation for the id.
        """
        conversation_object: Conversation = (
            await self.conversation_repository.get_by_id(conversation_id)
        )
        if not conversation_object:
            raise ValueError(f"Conversation {conversation_id} not found")
        return conversation_object

    async def _build_history(self, conversation_object, find_all=False, limit=20):
        """Render a conversation's messages as newline-joined ``role: content``.

        When ``find_all`` is false only the last ``limit`` messages are kept.
        A non-positive ``limit`` keeps none (a bare ``messages[-0:]`` would
        otherwise return the whole list). Returns ``""`` when nothing is kept.
        """
        messages: list[Message] = (
            await self.message_repository.get_messages_by_conversation(
                conversation_object
            )
        )
        messages = messages or []
        if not find_all:
            messages = messages[-limit:] if limit > 0 else []
        return "\n".join(f"{message.role}: {message.content}" for message in messages)

    async def validate_conversation(self, user_id, conversation_id) -> bool:
        """Return whether ``conversation_id`` exists and belongs to ``user_id``.

        Returns ``False`` when the user or the conversation cannot be found,
        or when the conversation's user differs from the given user.
        """
        user: User = await self.user_repository.get_by_id(user_id)
        if not user:
            # Previously ``user.id`` was dereferenced unconditionally, raising
            # AttributeError for an unknown user instead of returning False.
            return False
        conversation_object: Conversation = (
            await self.conversation_repository.get_by_id(conversation_id)
        )
        if not conversation_object:
            return False
        return conversation_object.user.id == user.id

    async def get_conversation_history(self, conversation_id, find_all=False, limit=20):
        """Return ``{"conversation_history": str}`` for a conversation.

        Args:
            conversation_id: id of the conversation to read.
            find_all: when true, include every message and ignore ``limit``.
            limit: number of most recent messages to include otherwise.

        Raises:
            ValueError: if the conversation does not exist.
        """
        conversation_object = await self._get_conversation_or_raise(conversation_id)
        history = await self._build_history(
            conversation_object, find_all=find_all, limit=limit
        )
        return {"conversation_history": history}

    async def create_conversation(self, name, email, modality):
        """Sign the user in by name/email and insert a new conversation.

        The conversation is created with the given ``modality`` and an empty
        summary. Returns ``{"conversation_id": <id as str>}``.
        """
        sign_in_result = await self.auth_service.sign_in(
            UserSignInSchema(name=name, email=email)
        )
        user_object = await self.user_repository.get_by_id(sign_in_result["user_id"])
        conversation_object = await self.conversation_repository.insert_one(
            Conversation(user=user_object, modality=modality, summary="")
        )
        return {"conversation_id": f"{conversation_object.id}"}

    async def create_webrtc_connection(self, conversation_id, offer):
        """Open an OpenAI WebRTC connection for a voice conversation.

        Creates an OpenAI session (not text-only) and uses its ephemeral
        client secret to answer ``offer``. Returns the client's response.

        Raises:
            ValueError: if the conversation does not exist or its modality
                is not ``"voice"``.
        """
        conversation_object = await self._get_conversation_or_raise(conversation_id)
        if conversation_object.modality != "voice":
            raise ValueError("Invalid modality")
        async with self.openai_client() as client:
            session_data = await client.create_openai_session(text_mode_only=False)
            ephemeral_key = session_data["client_secret"]["value"]
            response = await client.create_webrtc_connection(
                ephemeral_key=ephemeral_key, offer=offer
            )
        return response

    async def conversation(
        self, websocket: WebSocket, conversation_id, modality, redis_client
    ):
        """Run a live text or voice session for a conversation over ``websocket``.

        The last 20 messages are passed to the session handler as context.
        On an unsupported modality or a missing conversation the socket is
        closed with code 1008 and ``ValueError`` is raised.
        """
        # Validate cheap input before touching the database.
        if modality not in self._SUPPORTED_MODALITIES:
            await websocket.close(code=1008, reason="Unsupported modality")
            raise ValueError("Unsupported modality")
        conversation_object: Conversation = (
            await self.conversation_repository.get_by_id(conversation_id)
        )
        if not conversation_object:
            await websocket.close(code=1008, reason="Conversation not found")
            raise ValueError(f"Conversation {conversation_id} not found")
        user_id = conversation_object.user.id
        # Reuse the already-loaded conversation instead of re-fetching it.
        conversation_history = {
            "conversation_history": await self._build_history(
                conversation_object, find_all=False, limit=20
            )
        }
        session_kwargs = {
            "conversation_id": conversation_id,
            "conversation_history": conversation_history,
            "user_id": user_id,
            "redis_client": redis_client,
        }
        if modality == "voice":
            async with self.web_rtc_service() as rtc_service:
                await rtc_service.handle_voice_conversion(
                    websocket=websocket, **session_kwargs
                )
        else:
            async with self.websocket_service() as ws_service:
                await ws_service.handle_text_conversation_with_openai_websockets(
                    client_websocket=websocket, **session_kwargs
                )

    async def create_conversation_summary(self, user_id, conversation_id):
        """Generate, store and return a summary of the entire conversation.

        ``user_id`` is kept for backward compatibility and is not used.
        An empty conversation gets an empty summary without calling OpenAI.
        Returns ``{"conversation_summary": str}`` read back from the update.

        Raises:
            ValueError: if the conversation does not exist.
        """
        conversation_object = await self._get_conversation_or_raise(conversation_id)
        history = await self._build_history(conversation_object, find_all=True)
        if history:
            async with self.openai_client() as client:
                # Interpolate the transcript text itself; previously the whole
                # {"conversation_history": ...} dict repr was sent to the model.
                conversation_summary = await client.text_generation(
                    query=f"Generate Conversation Summary\n\n {history}"
                )
        else:
            conversation_summary = ""
        conversation_object.summary = conversation_summary
        updated_conversation: Conversation = await self.conversation_repository.update(
            conversation_object.id, conversation_object
        )
        return {"conversation_summary": updated_conversation.summary}