File size: 5,942 Bytes
a3aa6c1
eb474ee
 
 
 
14bfd4c
eb474ee
d4f6849
af7af6c
eb474ee
 
af7af6c
 
 
 
 
 
 
 
a3aa6c1
eb474ee
 
d4f6849
 
272c2c0
d4f6849
 
a3aa6c1
d4f6849
a3aa6c1
07aae89
00662bb
a3aa6c1
b8bc03f
a3aa6c1
 
 
d4f6849
a3aa6c1
b8bc03f
a3aa6c1
00662bb
 
 
a3aa6c1
b8bc03f
a3aa6c1
 
 
00662bb
a3aa6c1
00662bb
 
a3aa6c1
00662bb
 
a3aa6c1
00662bb
a3aa6c1
 
d4f6849
 
eb474ee
 
d4f6849
a3aa6c1
eb474ee
 
269b3c2
eb474ee
af7af6c
 
 
 
eb474ee
 
a3aa6c1
eb474ee
 
af7af6c
eb474ee
be1d374
 
 
 
af7af6c
be1d374
269b3c2
eb474ee
fcfe4b3
eb474ee
 
 
b8bc03f
eb474ee
a3aa6c1
 
 
eb474ee
 
 
 
 
 
af7af6c
eb474ee
 
 
 
af7af6c
 
eb474ee
a3aa6c1
 
 
eb474ee
 
 
 
 
 
1c4497b
b8bc03f
a3aa6c1
 
00662bb
 
eb474ee
 
af7af6c
 
 
 
a3aa6c1
af7af6c
6fb66ca
a3aa6c1
 
 
 
 
6fb66ca
 
a3aa6c1
 
 
eb474ee
 
a3aa6c1
 
 
 
 
be1d374
fcfe4b3
be1d374
 
af7af6c
b8bc03f
af7af6c
a3aa6c1
 
 
 
be1d374
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import os
from typing import List, Optional

from fastapi import (
    APIRouter,
    HTTPException,
    WebSocket,
    WebSocketDisconnect,
)

from src.config import logger
from src.services import ConversationService
from src.schemas import (
    CreateConversationSchema,
    CreateWebrtcConnectionSchema,
    CreateConversationSummarySchema,
    CreateConversationResponse,
    CreateConversationSummaryResponse,
    CreateWebrtcConnectionResponse,
)
from src.utils import JWTUtil, RedisClient


class ConnectionManager:
    """Track accepted WebSocket connections and mirror their state in Redis.

    While a connection is open, the Redis key ``session:<conversation_id>``
    holds ``"active"``; on disconnect the key is deleted.
    """

    def __init__(self):
        # JWT helper; not used by the methods in this class.
        self.jwt = JWTUtil()
        # WebSockets that have been accepted and not yet disconnected.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket, redis_client):
        """Accept *websocket*, register it, and mark its session active in Redis.

        The conversation id is taken from the ``conversation_id`` query
        parameter. NOTE(review): if that parameter is missing, the key written
        is ``session:None``; callers may want to reject such connections.
        """
        await websocket.accept()
        conversation_id = websocket.query_params.get("conversation_id")

        self.active_connections.append(websocket)
        await self._update_redis_status(
            unique_id=f"{conversation_id}",
            redis_client=redis_client,
            status="active",
        )

    async def disconnect(
        self, websocket: WebSocket, redis_client, conversation_id: Optional[str]
    ):
        """Unregister *websocket* (if registered) and delete its Redis session key.

        Safe to call more than once for the same websocket: the list removal
        is guarded by a membership check.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

        await self._update_redis_status(
            unique_id=f"{conversation_id}",
            redis_client=redis_client,
            status=None,
        )

    @staticmethod
    def _session_expiry() -> Optional[int]:
        """Return the session TTL in seconds from ``REDIS_SESSION_EXPIRY``.

        Returns ``None`` (no expiry) when the variable is unset or blank.

        Raises:
            ValueError: if the variable is set to a non-integer value.
        """
        raw = os.getenv("REDIS_SESSION_EXPIRY")
        if raw is None or not raw.strip():
            return None
        # Environment values are strings; Redis ``EX`` expects an integer.
        return int(raw)

    async def _update_redis_status(
        self, unique_id: str, redis_client, status: Optional[str]
    ):
        """Set ``session:<unique_id>`` to *status*, or delete it when *status* is falsy.

        After the write, the key is read back and logged.
        """
        key = f"session:{unique_id}"
        if status:
            await redis_client.set(key, status, ex=self._session_expiry())
        else:
            await redis_client.delete(key)

        redis_status = await redis_client.get(key)
        logger.info(f"Redis status for user {unique_id}: {redis_status}")


class ConversationController:
    """REST and WebSocket endpoints for conversations.

    ``api_router`` carries the HTTP routes; ``websocket_router`` carries the
    live conversation socket. Every handler opens a fresh
    ``ConversationService`` via ``async with``.
    """

    def __init__(self):
        self.websocket_connection_manager = ConnectionManager()
        self.redis_client = RedisClient().client
        # The service class itself is stored; handlers instantiate it per call.
        self.service = ConversationService
        self.api_router = APIRouter()
        self.websocket_router = APIRouter()
        self.api_router.add_api_route(
            "/conversations",
            self.create_conversation,
            methods=["POST"],
            response_model=CreateConversationResponse,
        )
        # NOTE(review): the ``{conversation_id}`` path segment in the next two
        # routes is not declared as a handler parameter; the handlers read
        # ``data.conversation_id`` from the request body instead. Confirm the
        # two are meant to agree.
        self.api_router.add_api_route(
            "/conversations/{conversation_id}",
            self.create_webrtc_connection,
            methods=["POST"],
            response_model=CreateWebrtcConnectionResponse,
        )
        self.api_router.add_api_route(
            "/conversations/{conversation_id}/summary",
            self.create_conversation_summary,
            methods=["POST"],
            response_model=CreateConversationSummaryResponse,
        )
        self.websocket_router.add_websocket_route("/conversations", self.conversation)

    async def create_conversation(self, data: CreateConversationSchema):
        """Create a conversation from ``data.name``, ``data.email`` and ``data.modality``.

        Raises:
            HTTPException: re-raised unchanged from the service, or status 500
                wrapping any other error.
        """
        try:
            async with self.service() as service:
                return await service.create_conversation(
                    name=data.name, email=data.email, modality=data.modality
                )
        except HTTPException:
            raise

        except Exception as e:
            logger.error(f"Error creating conversation: {str(e)}")
            raise HTTPException(
                status_code=500, detail="Failed to create conversation"
            ) from e

    async def create_webrtc_connection(
        self,
        data: CreateWebrtcConnectionSchema,
    ):
        """Create a WebRTC connection for ``data.conversation_id`` from ``data.offer``.

        Raises:
            HTTPException: re-raised unchanged from the service, or status 500
                wrapping any other error.
        """
        try:
            async with self.service() as service:
                return await service.create_webrtc_connection(
                    conversation_id=data.conversation_id,
                    offer=data.offer.model_dump(),
                )
        except HTTPException:
            raise

        except Exception as e:
            logger.error(f"Error creating WebRTC connection: {str(e)}")
            raise HTTPException(
                status_code=500, detail="Failed to create WebRTC connection"
            ) from e

    async def conversation(self, websocket: WebSocket):
        """Run a live conversation over *websocket*.

        Reads ``conversation_id`` and ``modality`` from the query string,
        registers the connection, and delegates to the service. The connection
        is always unregistered and its Redis session cleared when the handler
        exits, whether the service returns normally, the client disconnects,
        or an error is raised.
        """
        await self.websocket_connection_manager.connect(
            websocket=websocket, redis_client=self.redis_client
        )
        conversation_id = websocket.query_params.get("conversation_id")
        modality = websocket.query_params.get("modality")
        try:
            async with self.service() as service:
                await service.conversation(
                    websocket=websocket,
                    conversation_id=conversation_id,
                    modality=modality,
                    redis_client=self.redis_client,
                )
        except WebSocketDisconnect:
            logger.info("WebSocket connection closed")

        except HTTPException:
            raise

        except Exception as e:
            logger.error(f"Error in WebSocket conversation: {str(e)}")

        finally:
            # Previously skipped on normal return and on HTTPException,
            # leaking the entry in active_connections and the Redis session key.
            await self.websocket_connection_manager.disconnect(
                websocket=websocket,
                redis_client=self.redis_client,
                conversation_id=conversation_id,
            )

    async def create_conversation_summary(self, data: CreateConversationSummarySchema):
        """Create a summary for ``data.conversation_id``.

        Raises:
            HTTPException: re-raised unchanged from the service, or status 500
                wrapping any other error.
        """
        try:
            async with self.service() as service:
                return await service.create_conversation_summary(
                    conversation_id=data.conversation_id
                )

        except HTTPException:
            raise

        except Exception as e:
            logger.error(f"Error creating conversation summary: {str(e)}")
            raise HTTPException(
                status_code=500, detail="Failed to create conversation summary"
            ) from e