Spaces:
Runtime error
Runtime error
File size: 5,918 Bytes
eb474ee 2efe331 b8bc03f be1d374 1e0a76d 2efe331 b8bc03f 2efe331 d4f6849 eb474ee 2efe331 d4f6849 b8bc03f 1e0a76d eb474ee af7af6c 1c4497b af7af6c 1c4497b af7af6c 33fe7cc b8bc03f 1e0a76d 4267717 1e0a76d 4267717 eb474ee b8bc03f 4267717 eb474ee 4267717 eb474ee 4267717 eb474ee af7af6c b8bc03f af7af6c b8bc03f d4f6849 33fe7cc d4f6849 33fe7cc 00662bb d4f6849 33fe7cc 00662bb d4f6849 af7af6c 1e0a76d b8bc03f 33fe7cc 1e0a76d be1d374 33fe7cc be1d374 1e0a76d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | from fastapi import WebSocket
from src.config import logger
from src.utils import OpenAIClient
from src.schemas import UserSignInSchema
from src.models import Conversation, User, Message
from src.repositories import ConversationRepository, MessageRepository, UserRepository
from ._auth_service import AuthService
from ._websocket_service import WebSocketService
from ._web_rtc_service import WebRTCService
class ConversationService:
    """Coordinate conversation creation, ownership checks, history, live sessions and summaries.

    Persistence goes through the conversation, message and user repositories.
    Live traffic is delegated to ``WebRTCService`` (voice) or
    ``WebSocketService`` (text). Model calls go through ``OpenAIClient``.
    """

    # Modalities accepted by ``conversation``; anything else closes the socket.
    SUPPORTED_MODALITIES = ("text", "voice")

    def __init__(self):
        # These three are stored as classes (not instances) and are instantiated
        # per use as async context managers, e.g. ``async with self.openai_client()``.
        self.openai_client = OpenAIClient
        self.websocket_service = WebSocketService
        self.web_rtc_service = WebRTCService
        self.auth_service = AuthService()
        self.conversation_repository = ConversationRepository()
        self.message_repository = MessageRepository()
        self.user_repository = UserRepository()

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        # Nothing to release; returning None lets any exception propagate.
        pass

    async def _get_conversation_or_raise(self, conversation_id) -> "Conversation":
        """Fetch a conversation by id, raising if the repository returns nothing.

        Raises:
            Exception: ``"Conversation not found"`` when no conversation exists.
        """
        conversation_object = await self.conversation_repository.get_by_id(
            conversation_id
        )
        if not conversation_object:
            raise Exception("Conversation not found")
        return conversation_object

    async def validate_conversation(self, user_id, conversation_id) -> bool:
        """Return True only if the conversation exists and is owned by ``user_id``.

        Returns False (instead of raising) when the user, the conversation, or
        the conversation's owner reference is missing.
        """
        user: User = await self.user_repository.get_by_id(user_id)
        conversation_object: Conversation = (
            await self.conversation_repository.get_by_id(conversation_id)
        )
        if not conversation_object or user is None:
            return False
        owner = conversation_object.user
        if owner is None:
            return False
        return owner.id == user.id

    async def get_conversation_history(self, conversation_id, find_all=False, limit=20):
        """Return a conversation's messages as one newline-joined string.

        Args:
            conversation_id: id of the conversation to read.
            find_all: when True, include every message and ignore ``limit``.
            limit: how many trailing messages (in repository order, presumably
                chronological — confirm) to keep when ``find_all`` is False.
                Values <= 0 yield an empty history; a bare ``messages[-0:]``
                slice would silently return *all* messages instead.

        Returns:
            ``{"conversation_history": str}``, one ``"<role>: <content>"`` line
            per message; the string is empty when no messages are selected.
        """
        conversation_object: Conversation = (
            await self.conversation_repository.get_by_id(conversation_id)
        )
        messages: list[Message] = (
            await self.message_repository.get_messages_by_conversation(
                conversation_object
            )
        )
        if not messages:
            return {"conversation_history": ""}
        if not find_all:
            if limit <= 0:
                return {"conversation_history": ""}
            messages = messages[-limit:]
        conversation_history = "\n".join(
            f"{message.role}: {message.content}" for message in messages
        )
        return {"conversation_history": conversation_history}

    async def create_conversation(self, name, email, modality):
        """Sign the user in via ``AuthService`` and create an empty-summary conversation.

        Returns:
            ``{"conversation_id": str}`` for the inserted conversation.
        """
        user_id_object = await self.auth_service.sign_in(
            UserSignInSchema(name=name, email=email)
        )
        user_object = await self.user_repository.get_by_id(user_id_object["user_id"])
        conversation_object = await self.conversation_repository.insert_one(
            Conversation(user=user_object, modality=modality, summary="")
        )
        return {"conversation_id": f"{conversation_object.id}"}

    async def create_webrtc_connection(self, conversation_id, offer):
        """Create an OpenAI session and forward ``offer`` with its ephemeral key.

        Returns whatever ``OpenAIClient.create_webrtc_connection`` returns.

        Raises:
            Exception: ``"Conversation not found"`` for an unknown id, or
                ``"Invalid modality"`` when the conversation is not ``"voice"``.
        """
        conversation_object = await self._get_conversation_or_raise(conversation_id)
        if conversation_object.modality != "voice":
            raise Exception("Invalid modality")
        async with self.openai_client() as client:
            session_data = await client.create_openai_session(text_mode_only=False)
            ephemeral_key = session_data["client_secret"]["value"]
            response = await client.create_webrtc_connection(
                ephemeral_key=ephemeral_key, offer=offer
            )
        return response

    async def conversation(
        self, websocket: "WebSocket", conversation_id, modality, redis_client
    ):
        """Run a live conversation over ``websocket`` by delegating to a handler.

        ``"voice"`` goes to ``WebRTCService.handle_voice_conversion`` and
        ``"text"`` to ``WebSocketService.handle_text_conversation_with_openai_websockets``.
        Both receive the last 20 messages of history and the owner's user id.

        NOTE(review): the caller-supplied ``modality`` is not compared with the
        stored ``conversation_object.modality`` — confirm whether a mismatch
        should be rejected.

        Raises:
            Exception: after closing the socket with code 1008 when ``modality``
                is unsupported or the conversation does not exist.
        """
        if modality not in self.SUPPORTED_MODALITIES:
            await websocket.close(code=1008, reason="Unsupported modality")
            raise Exception("Unsupported modality")
        conversation_object: Conversation = (
            await self.conversation_repository.get_by_id(conversation_id)
        )
        if not conversation_object:
            # Close the socket before raising, mirroring the modality check above.
            await websocket.close(code=1008, reason="Conversation not found")
            raise Exception("Conversation not found")
        user_id = conversation_object.user.id
        conversation_history = await self.get_conversation_history(
            conversation_id, find_all=False, limit=20
        )
        if modality == "voice":
            async with self.web_rtc_service() as rtc_service:
                await rtc_service.handle_voice_conversion(
                    websocket=websocket,
                    conversation_id=conversation_id,
                    conversation_history=conversation_history,
                    user_id=user_id,
                    redis_client=redis_client,
                )
        else:
            async with self.websocket_service() as ws_service:
                await ws_service.handle_text_conversation_with_openai_websockets(
                    client_websocket=websocket,
                    conversation_id=conversation_id,
                    conversation_history=conversation_history,
                    user_id=user_id,
                    redis_client=redis_client,
                )

    async def create_conversation_summary(self, user_id, conversation_id):
        """Generate, persist and return a summary of the whole conversation.

        Args:
            user_id: currently unused. NOTE(review): ownership is not checked
                here; callers presumably run ``validate_conversation`` first —
                confirm.
            conversation_id: id of the conversation to summarise.

        Returns:
            ``{"conversation_summary": str}`` taken from the repository update;
            the summary is set to ``""`` when the conversation has no messages.

        Raises:
            Exception: ``"Conversation not found"`` for an unknown id.
        """
        conversation_object = await self._get_conversation_or_raise(conversation_id)
        history = await self.get_conversation_history(conversation_id, find_all=True)
        history_text = history["conversation_history"]
        if not history_text:
            conversation_summary = ""
        else:
            async with self.openai_client() as client:
                # Send the history text itself: interpolating the wrapper dict
                # would put its repr (with escaped "\n") into the prompt.
                conversation_summary = await client.text_generation(
                    query=f"Generate Conversation Summary\n\n {history_text}"
                )
        conversation_object.summary = conversation_summary
        updated_conversation = await self.conversation_repository.update(
            conversation_object.id, conversation_object
        )
        return {"conversation_summary": updated_conversation.summary}
|