File size: 5,918 Bytes
eb474ee
 
 
2efe331
b8bc03f
be1d374
1e0a76d
2efe331
b8bc03f
2efe331
d4f6849
eb474ee
 
 
 
 
2efe331
d4f6849
b8bc03f
1e0a76d
 
 
eb474ee
 
 
 
 
 
 
af7af6c
1c4497b
af7af6c
 
 
 
 
 
 
1c4497b
af7af6c
33fe7cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8bc03f
 
 
 
 
1e0a76d
4267717
1e0a76d
4267717
eb474ee
b8bc03f
4267717
 
 
 
 
 
eb474ee
4267717
 
 
eb474ee
 
4267717
eb474ee
af7af6c
b8bc03f
af7af6c
b8bc03f
 
 
 
d4f6849
 
 
 
33fe7cc
 
 
 
d4f6849
 
 
33fe7cc
 
 
 
00662bb
d4f6849
 
 
 
33fe7cc
 
 
 
00662bb
d4f6849
af7af6c
 
1e0a76d
 
 
b8bc03f
33fe7cc
 
1e0a76d
be1d374
33fe7cc
 
 
 
 
 
 
 
be1d374
 
 
 
 
 
1e0a76d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from fastapi import WebSocket

from src.config import logger
from src.utils import OpenAIClient
from src.schemas import UserSignInSchema
from src.models import Conversation, User, Message
from src.repositories import ConversationRepository, MessageRepository, UserRepository

from ._auth_service import AuthService
from ._websocket_service import WebSocketService
from ._web_rtc_service import WebRTCService


class ConversationService:
    """Orchestrate conversations: creation, history, live sessions and summaries.

    Persistence goes through the conversation / message / user repositories.
    OpenAI calls go through ``OpenAIClient``. Live sessions are delegated to
    ``WebSocketService`` (text modality) or ``WebRTCService`` (voice modality).
    """

    def __init__(self):
        # The OpenAI client and the session services are stored as classes,
        # not instances: each use below instantiates one inside `async with`.
        self.openai_client = OpenAIClient
        self.websocket_service = WebSocketService
        self.web_rtc_service = WebRTCService
        self.auth_service = AuthService()
        self.conversation_repository = ConversationRepository()
        self.message_repository = MessageRepository()
        self.user_repository = UserRepository()

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        # Nothing to release; returning None lets exceptions propagate.
        return None

    async def _get_conversation_or_raise(self, conversation_id) -> "Conversation":
        """Fetch a conversation by id.

        Raises:
            Exception: ``"Conversation not found"`` when the repository returns
                nothing for ``conversation_id``.
        """
        conversation_object = await self.conversation_repository.get_by_id(
            conversation_id
        )
        if not conversation_object:
            raise Exception("Conversation not found")
        return conversation_object

    async def _load_history(
        self, conversation_object, find_all=False, limit=20
    ) -> dict:
        """Render a conversation's messages as newline-joined ``role: content``.

        Args:
            conversation_object: Conversation whose messages are loaded.
            find_all: When True, include every message and ignore ``limit``.
            limit: When ``find_all`` is False, keep only the last ``limit``
                messages in the order the repository returns them (presumably
                chronological - confirm in MessageRepository). Values <= 0
                yield an empty history.

        Returns:
            ``{"conversation_history": str}``; the string is ``""`` when there
            are no messages to include.
        """
        messages = await self.message_repository.get_messages_by_conversation(
            conversation_object
        )
        if not find_all:
            # `messages[-0:]` is the whole list, so non-positive limits are
            # handled explicitly instead of silently returning everything.
            messages = messages[-limit:] if limit > 0 else []
        history = "\n".join(
            f"{message.role}: {message.content}" for message in messages
        )
        return {"conversation_history": history}

    async def validate_conversation(self, user_id, conversation_id) -> bool:
        """Return True only if the user and conversation exist and the user owns it."""
        user = await self.user_repository.get_by_id(user_id)
        if not user:
            return False
        conversation_object = await self.conversation_repository.get_by_id(
            conversation_id
        )
        if not conversation_object:
            return False
        return conversation_object.user.id == user.id

    async def get_conversation_history(self, conversation_id, find_all=False, limit=20):
        """Return ``{"conversation_history": str}`` for a conversation.

        Args:
            conversation_id: Id of the conversation to read.
            find_all: When True, include every message.
            limit: Number of trailing messages to keep when ``find_all`` is
                False; values <= 0 give an empty history.

        Raises:
            Exception: ``"Conversation not found"`` for an unknown id.
        """
        conversation_object = await self._get_conversation_or_raise(conversation_id)
        return await self._load_history(
            conversation_object, find_all=find_all, limit=limit
        )

    async def create_conversation(self, name, email, modality):
        """Sign the user in and create a conversation with an empty summary.

        Returns:
            ``{"conversation_id": str}`` for the newly inserted conversation.
        """
        user_id_object = await self.auth_service.sign_in(
            UserSignInSchema(name=name, email=email)
        )
        user_object = await self.user_repository.get_by_id(user_id_object["user_id"])
        conversation_object = await self.conversation_repository.insert_one(
            Conversation(user=user_object, modality=modality, summary="")
        )
        return {"conversation_id": f"{conversation_object.id}"}

    async def create_webrtc_connection(self, conversation_id, offer):
        """Open an OpenAI session and forward a WebRTC ``offer`` for a voice chat.

        Returns:
            Whatever ``OpenAIClient.create_webrtc_connection`` returns.

        Raises:
            Exception: ``"Conversation not found"`` for an unknown id, or
                ``"Invalid modality"`` when the conversation is not ``"voice"``.
        """
        conversation_object = await self._get_conversation_or_raise(conversation_id)
        if conversation_object.modality != "voice":
            raise Exception("Invalid modality")

        async with self.openai_client() as client:
            session_data = await client.create_openai_session(text_mode_only=False)
            # The session's client secret is used as the ephemeral key.
            ephemeral_key = session_data["client_secret"]["value"]
            return await client.create_webrtc_connection(
                ephemeral_key=ephemeral_key, offer=offer
            )

    async def conversation(
        self, websocket: "WebSocket", conversation_id, modality, redis_client
    ):
        """Run a live conversation over ``websocket``.

        Voice sessions go to ``WebRTCService.handle_voice_conversion``; text
        sessions go to
        ``WebSocketService.handle_text_conversation_with_openai_websockets``.
        Both receive the last 20 messages as context.

        Raises:
            Exception: ``"Unsupported modality"`` or ``"Conversation not found"``;
                the websocket is closed with code 1008 before raising.
        """
        # Validate the cheap argument first, before touching the database.
        if modality not in ("text", "voice"):
            await websocket.close(code=1008, reason="Unsupported modality")
            raise Exception("Unsupported modality")

        conversation_object = await self.conversation_repository.get_by_id(
            conversation_id
        )
        if not conversation_object:
            await websocket.close(code=1008, reason="Conversation not found")
            raise Exception("Conversation not found")
        user_id = conversation_object.user.id

        # Reuse the already-fetched conversation instead of looking it up again.
        conversation_history = await self._load_history(
            conversation_object, find_all=False, limit=20
        )

        if modality == "voice":
            async with self.web_rtc_service() as rtc_service:
                await rtc_service.handle_voice_conversion(
                    websocket=websocket,
                    conversation_id=conversation_id,
                    conversation_history=conversation_history,
                    user_id=user_id,
                    redis_client=redis_client,
                )
        else:
            async with self.websocket_service() as ws_service:
                await ws_service.handle_text_conversation_with_openai_websockets(
                    client_websocket=websocket,
                    conversation_id=conversation_id,
                    conversation_history=conversation_history,
                    user_id=user_id,
                    redis_client=redis_client,
                )

    async def create_conversation_summary(self, user_id, conversation_id):
        """Generate, store and return a summary of the full conversation.

        An empty conversation gets an empty summary without calling OpenAI.

        Returns:
            ``{"conversation_summary": str}`` as stored by the repository.

        Raises:
            Exception: ``"Conversation not found"`` for an unknown id.
        """
        # NOTE(review): `user_id` is accepted for interface compatibility but is
        # not checked against the conversation owner here - confirm callers do.
        conversation_object = await self._get_conversation_or_raise(conversation_id)
        history = (
            await self._load_history(conversation_object, find_all=True)
        )["conversation_history"]

        if history:
            async with self.openai_client() as client:
                # Send the transcript text itself, not the wrapping dict's repr.
                summary = await client.text_generation(
                    query=f"Generate Conversation Summary\n\n{history}"
                )
        else:
            summary = ""

        conversation_object.summary = summary
        updated_conversation = await self.conversation_repository.update(
            conversation_object.id, conversation_object
        )
        return {"conversation_summary": updated_conversation.summary}