ramanjitsingh1368 committed on
Commit
2efe331
·
1 Parent(s): 78c825a

Implement WebSocket service for handling conversations and integrate OpenAI client for text generation

Browse files
src/services/_conversation_service.py CHANGED
@@ -1,28 +1,18 @@
1
- import json
2
  from fastapi import WebSocket
3
  from bson import ObjectId
4
- from uuid import UUID
5
  from beanie import WriteRules
6
 
7
- from ._file_service import FileService
8
- from src.utils import OpenAIClient
9
- from src.repositories import ConversationRepository, MessageRepository
10
- from src.models import Conversation, Message, User
11
  from src.config import logger
 
 
 
 
12
 
13
 
14
  class ConversationService:
15
  def __init__(self):
16
  self.openai_client = OpenAIClient
17
- self.conversation_repository = ConversationRepository()
18
- self.message_repository = MessageRepository()
19
- self.file_service = FileService
20
- self.functions_dictionary = {
21
- "get_relevant_information": {
22
- "service": self.file_service,
23
- "function": "semantic_search",
24
- },
25
- }
26
 
27
  async def __aenter__(self):
28
  return self
@@ -53,229 +43,5 @@ class ConversationService:
53
  )
54
 
55
  async def conversation(self, websocket: WebSocket):
56
- try:
57
- query_params = websocket.query_params
58
- conversation_id = query_params.get("conversation_id")
59
- modality = query_params.get("modality")
60
-
61
- await websocket.accept()
62
-
63
- if not conversation_id or not modality:
64
- await websocket.close(
65
- code=1008, reason="Missing or invalid query parameters"
66
- )
67
- return
68
-
69
- if modality == "text":
70
- await self.handle_text_conversion(websocket, conversation_id)
71
- elif modality == "voice":
72
- await self.handle_voice_conversion(websocket, conversation_id)
73
- else:
74
- await websocket.close(code=1008, reason="Unsupported modality")
75
- return
76
- finally:
77
- await self.handle_conversation_summary(websocket, conversation_id)
78
- # await websocket.close()
79
-
80
- async def handle_text_conversion(self, websocket, conversation_id):
81
- while True:
82
- data = await websocket.receive_text()
83
- message = json.loads(data)
84
-
85
- openai_server_events = {
86
- "session_created": message["type"] == "session.created",
87
- "function_call": (
88
- message["type"] == "response.function_call_arguments.done"
89
- ),
90
- "ai_response": (
91
- message["type"] == "response.done"
92
- and message["response"]["status"] == "completed"
93
- and message["response"]["output"][0]["type"] == "message"
94
- and message["response"]["output"][0]["role"] == "assistant"
95
- ),
96
- }
97
-
98
- if openai_server_events["session_created"]:
99
- user_query = input("Enter your query: ")
100
- await self.handle_user_message(
101
- message_content=user_query, conversation_id=conversation_id
102
- )
103
-
104
- event_stage_1 = {
105
- "type": "conversation.item.create",
106
- "previous_item_id": None,
107
- "item": {
108
- "type": "message",
109
- "role": "user",
110
- "content": [
111
- {
112
- "type": "input_text",
113
- "text": user_query,
114
- }
115
- ],
116
- },
117
- }
118
- await websocket.send_text(json.dumps(event_stage_1))
119
- event_response_2 = {"type": "response.create"}
120
- await websocket.send_text(json.dumps(event_response_2))
121
-
122
- if openai_server_events["function_call"]:
123
- response = await self.handle_ai_function_call(message)
124
- event_response = {
125
- "type": "conversation.item.create",
126
- "previous_item_id": None,
127
- "item": {
128
- "type": "message",
129
- "role": "user",
130
- "content": [
131
- {
132
- "type": "input_text",
133
- "text": response,
134
- }
135
- ],
136
- },
137
- }
138
- await websocket.send_text(json.dumps(event_response))
139
- event_response_2 = {"type": "response.create"}
140
- await websocket.send_text(json.dumps(event_response_2))
141
-
142
- if openai_server_events["ai_response"]:
143
- ai_response = message["response"]["output"][0]["content"][0]["text"]
144
- await self.handle_ai_message(
145
- message_content=ai_response, conversation_id=conversation_id
146
- )
147
-
148
- user_query = input("Enter your query: ")
149
- await self.handle_user_message(
150
- message_content=user_query, conversation_id=conversation_id
151
- )
152
-
153
- event_stage_1 = {
154
- "type": "conversation.item.create",
155
- "previous_item_id": None,
156
- "item": {
157
- "type": "message",
158
- "role": "user",
159
- "content": [
160
- {
161
- "type": "input_text",
162
- "text": user_query,
163
- }
164
- ],
165
- },
166
- }
167
- await websocket.send_text(json.dumps(event_stage_1))
168
- event_response_2 = {"type": "response.create"}
169
- await websocket.send_text(json.dumps(event_response_2))
170
-
171
- elif message["type"] == "error":
172
- logger.error(f"Error: {message['error']}")
173
-
174
- async def handle_voice_conversion(self, websocket, conversation_id):
175
- while True:
176
- data = await websocket.receive_text()
177
- message = json.loads(data)
178
-
179
- openai_server_events = {
180
- "function_call": (
181
- message["type"] == "response.function_call_arguments.done"
182
- ),
183
- "input_transcription": (
184
- message["type"]
185
- == "conversation.item.input_audio_transcription.completed"
186
- ),
187
- "output_transcription": (
188
- message["type"] == "response.done"
189
- and message["response"]["status"] == "completed"
190
- and message["response"]["output"][0]["type"] == "message"
191
- ),
192
- }
193
-
194
- if openai_server_events["function_call"]:
195
- response = await self.handle_ai_function_call(message)
196
- event_response = {
197
- "type": "conversation.item.create",
198
- "previous_item_id": None,
199
- "item": {
200
- "type": "message",
201
- "role": "user",
202
- "content": [
203
- {
204
- "type": "input_text",
205
- "text": response,
206
- }
207
- ],
208
- },
209
- }
210
- await websocket.send_text(json.dumps(event_response))
211
- event_response_2 = {"type": "response.create"}
212
- await websocket.send_text(json.dumps(event_response_2))
213
-
214
- elif openai_server_events["input_transcription"]:
215
- user_message = message["transcript"]
216
- await self.handle_user_message(
217
- message_content=user_message, conversation_id=conversation_id
218
- )
219
-
220
- elif openai_server_events["output_transcription"]:
221
- ai_response = message["response"]["output"][0]["content"][0][
222
- "transcript"
223
- ]
224
- await self.handle_ai_message(
225
- message_content=ai_response, conversation_id=conversation_id
226
- )
227
-
228
- elif message["type"] == "error":
229
- logger.error(f"Error: {message['error']}")
230
-
231
- async def handle_user_message(self, message_content, conversation_id):
232
- logger.info(f"User Query: {message_content}")
233
- conversation_object = await Conversation.get(ObjectId(conversation_id))
234
- message_object = Message(
235
- conversation=conversation_object,
236
- role="user",
237
- content=message_content,
238
- )
239
- return await message_object.save(link_rule=WriteRules.WRITE)
240
-
241
- async def handle_ai_message(self, message_content, conversation_id):
242
- logger.info(f"AI Response: {message_content}")
243
- conversation_object = await Conversation.get(ObjectId(conversation_id))
244
- message_object = Message(
245
- conversation=conversation_object,
246
- role="assistant",
247
- content=message_content,
248
- )
249
- return await message_object.save(link_rule=WriteRules.WRITE)
250
-
251
- async def handle_conversation_summary(self, conversation_id):
252
- conversation_object = await Conversation.get(ObjectId(conversation_id))
253
- messages: list[Message] = await Message.find_many(
254
- Message.conversation.id == conversation_object.id,
255
- fetch_links=True,
256
- ).to_list()
257
- conversation_summary = "\n".join(
258
- [f"{message.role}: {message.content}" for message in messages]
259
- )
260
- conversation_object.summary = conversation_summary
261
- await conversation_object.save(link_rule=WriteRules.WRITE)
262
-
263
- async def handle_ai_function_call(self, data):
264
- tool_name = data["name"]
265
- arguments = json.loads(data["arguments"])
266
- logger.info(f"Function call: {tool_name} \nArguments: {arguments}")
267
-
268
- function_info = self.functions_dictionary.get(tool_name)
269
- if not function_info:
270
- raise AttributeError(f"Function {tool_name} not found in dictionary")
271
-
272
- service_class = function_info["service"]
273
- async with service_class() as service:
274
- func = getattr(service, function_info["function"], None)
275
- if not func:
276
- raise AttributeError(
277
- f"No such function {function_info['function']} in service"
278
- )
279
- response = await func(arguments["query"])
280
- response = f"Context : {response}\n\n# Use this context to respond to the user query : {arguments['query']}"
281
- return response
 
 
1
  from fastapi import WebSocket
2
  from bson import ObjectId
 
3
  from beanie import WriteRules
4
 
 
 
 
 
5
  from src.config import logger
6
+ from src.utils import OpenAIClient
7
+ from src.models import Conversation, User
8
+
9
+ from ._websocket_service import WebSocketService
10
 
11
 
12
  class ConversationService:
13
  def __init__(self):
14
  self.openai_client = OpenAIClient
15
+ self.websocket_service = WebSocketService
 
 
 
 
 
 
 
 
16
 
17
  async def __aenter__(self):
18
  return self
 
43
  )
44
 
45
  async def conversation(self, websocket: WebSocket):
46
+ async with self.websocket_service() as ws_service:
47
+ await ws_service.handle_conversation(websocket)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/services/_websocket_service.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from fastapi import WebSocket
3
+ from bson import ObjectId
4
+ from beanie import WriteRules
5
+
6
+ from ._file_service import FileService
7
+ from src.utils import OpenAIClient
8
+ from src.repositories import ConversationRepository, MessageRepository
9
+ from src.models import Conversation, Message
10
+ from src.config import logger
11
+
12
+
13
+ class WebSocketService:
14
+ def __init__(self):
15
+ self.openai_client = OpenAIClient
16
+ self.conversation_repository = ConversationRepository()
17
+ self.message_repository = MessageRepository()
18
+ self.file_service = FileService
19
+ self.functions_dictionary = {
20
+ "get_relevant_information": {
21
+ "service": self.file_service,
22
+ "function": "semantic_search",
23
+ },
24
+ }
25
+
26
+ async def __aenter__(self):
27
+ return self
28
+
29
+ async def __aexit__(self, *args):
30
+ pass
31
+
32
+ async def handle_conversation(self, websocket: WebSocket):
33
+ try:
34
+ query_params = websocket.query_params
35
+ conversation_id = query_params.get("conversation_id")
36
+ modality = query_params.get("modality")
37
+
38
+ await websocket.accept()
39
+
40
+ if not conversation_id or not modality:
41
+ await websocket.close(
42
+ code=1008, reason="Missing or invalid query parameters"
43
+ )
44
+ return
45
+
46
+ if modality == "text":
47
+ await self.handle_text_conversion(websocket, conversation_id)
48
+ elif modality == "voice":
49
+ await self.handle_voice_conversion(websocket, conversation_id)
50
+ else:
51
+ await websocket.close(code=1008, reason="Unsupported modality")
52
+ return
53
+ finally:
54
+ await self.handle_conversation_summary(conversation_id)
55
+
56
+ async def handle_text_conversion(self, websocket, conversation_id):
57
+ while True:
58
+ data = await websocket.receive_text()
59
+ message = json.loads(data)
60
+
61
+ openai_server_events = {
62
+ "session_created": message["type"] == "session.created",
63
+ "function_call": (
64
+ message["type"] == "response.function_call_arguments.done"
65
+ ),
66
+ "ai_response": (
67
+ message["type"] == "response.done"
68
+ and message["response"]["status"] == "completed"
69
+ and message["response"]["output"][0]["type"] == "message"
70
+ and message["response"]["output"][0]["role"] == "assistant"
71
+ ),
72
+ }
73
+
74
+ if openai_server_events["session_created"]:
75
+ user_query = input("Enter your query: ")
76
+ await self.handle_user_message(
77
+ message_content=user_query, conversation_id=conversation_id
78
+ )
79
+
80
+ event_stage_1 = {
81
+ "type": "conversation.item.create",
82
+ "previous_item_id": None,
83
+ "item": {
84
+ "type": "message",
85
+ "role": "user",
86
+ "content": [
87
+ {
88
+ "type": "input_text",
89
+ "text": user_query,
90
+ }
91
+ ],
92
+ },
93
+ }
94
+ await websocket.send_text(json.dumps(event_stage_1))
95
+ event_response_2 = {"type": "response.create"}
96
+ await websocket.send_text(json.dumps(event_response_2))
97
+
98
+ if openai_server_events["function_call"]:
99
+ response = await self.handle_ai_function_call(message)
100
+ event_response = {
101
+ "type": "conversation.item.create",
102
+ "previous_item_id": None,
103
+ "item": {
104
+ "type": "message",
105
+ "role": "user",
106
+ "content": [
107
+ {
108
+ "type": "input_text",
109
+ "text": response,
110
+ }
111
+ ],
112
+ },
113
+ }
114
+ await websocket.send_text(json.dumps(event_response))
115
+ event_response_2 = {"type": "response.create"}
116
+ await websocket.send_text(json.dumps(event_response_2))
117
+
118
+ if openai_server_events["ai_response"]:
119
+ ai_response = message["response"]["output"][0]["content"][0]["text"]
120
+ await self.handle_ai_message(
121
+ message_content=ai_response, conversation_id=conversation_id
122
+ )
123
+
124
+ user_query = input("Enter your query: ")
125
+ await self.handle_user_message(
126
+ message_content=user_query, conversation_id=conversation_id
127
+ )
128
+
129
+ event_stage_1 = {
130
+ "type": "conversation.item.create",
131
+ "previous_item_id": None,
132
+ "item": {
133
+ "type": "message",
134
+ "role": "user",
135
+ "content": [
136
+ {
137
+ "type": "input_text",
138
+ "text": user_query,
139
+ }
140
+ ],
141
+ },
142
+ }
143
+ await websocket.send_text(json.dumps(event_stage_1))
144
+ event_response_2 = {"type": "response.create"}
145
+ await websocket.send_text(json.dumps(event_response_2))
146
+
147
+ elif message["type"] == "error":
148
+ logger.error(f"Error: {message['error']}")
149
+
150
+ async def handle_voice_conversion(self, websocket, conversation_id):
151
+ while True:
152
+ data = await websocket.receive_text()
153
+ message = json.loads(data)
154
+
155
+ openai_server_events = {
156
+ "function_call": (
157
+ message["type"] == "response.function_call_arguments.done"
158
+ ),
159
+ "input_transcription": (
160
+ message["type"]
161
+ == "conversation.item.input_audio_transcription.completed"
162
+ ),
163
+ "output_transcription": (
164
+ message["type"] == "response.done"
165
+ and message["response"]["status"] == "completed"
166
+ and message["response"]["output"][0]["type"] == "message"
167
+ ),
168
+ }
169
+
170
+ if openai_server_events["function_call"]:
171
+ response = await self.handle_ai_function_call(message)
172
+ event_response = {
173
+ "type": "conversation.item.create",
174
+ "previous_item_id": None,
175
+ "item": {
176
+ "type": "message",
177
+ "role": "user",
178
+ "content": [
179
+ {
180
+ "type": "input_text",
181
+ "text": response,
182
+ }
183
+ ],
184
+ },
185
+ }
186
+ await websocket.send_text(json.dumps(event_response))
187
+ event_response_2 = {"type": "response.create"}
188
+ await websocket.send_text(json.dumps(event_response_2))
189
+
190
+ elif openai_server_events["input_transcription"]:
191
+ user_message = message["transcript"]
192
+ await self.handle_user_message(
193
+ message_content=user_message, conversation_id=conversation_id
194
+ )
195
+
196
+ elif openai_server_events["output_transcription"]:
197
+ ai_response = message["response"]["output"][0]["content"][0][
198
+ "transcript"
199
+ ]
200
+ await self.handle_ai_message(
201
+ message_content=ai_response, conversation_id=conversation_id
202
+ )
203
+
204
+ elif message["type"] == "error":
205
+ logger.error(f"Error: {message['error']}")
206
+
207
+ async def handle_user_message(self, message_content, conversation_id):
208
+ logger.info(f"User Query: {message_content}")
209
+ conversation_object = await Conversation.get(ObjectId(conversation_id))
210
+ message_object = Message(
211
+ conversation=conversation_object,
212
+ role="user",
213
+ content=message_content,
214
+ )
215
+ return await message_object.save(link_rule=WriteRules.WRITE)
216
+
217
+ async def handle_ai_message(self, message_content, conversation_id):
218
+ logger.info(f"AI Response: {message_content}")
219
+ conversation_object = await Conversation.get(ObjectId(conversation_id))
220
+ message_object = Message(
221
+ conversation=conversation_object,
222
+ role="assistant",
223
+ content=message_content,
224
+ )
225
+ return await message_object.save(link_rule=WriteRules.WRITE)
226
+
227
+ async def handle_conversation_summary(self, conversation_id):
228
+ conversation_object = await Conversation.get(ObjectId(conversation_id))
229
+ messages: list[Message] = await Message.find_many(
230
+ Message.conversation.id == conversation_object.id,
231
+ fetch_links=True,
232
+ ).to_list()
233
+
234
+ conversation_history = "\n".join(
235
+ [f"{message.role}: {message.content}" for message in messages]
236
+ )
237
+ query = "Generate Conversation Summary\n\n" + conversation_history
238
+ async with self.openai_client() as client:
239
+ conversation_summary = await client.text_generation(query=query)
240
+
241
+ logger.info(f"Conversation Summary: {conversation_summary}")
242
+
243
+ conversation_object.summary = conversation_summary
244
+ await conversation_object.save(link_rule=WriteRules.WRITE)
245
+
246
+ async def handle_ai_function_call(self, data):
247
+ tool_name = data["name"]
248
+ arguments = json.loads(data["arguments"])
249
+ logger.info(f"Function call: {tool_name} \nArguments: {arguments}")
250
+
251
+ function_info = self.functions_dictionary.get(tool_name)
252
+ if not function_info:
253
+ raise AttributeError(f"Function {tool_name} not found in dictionary")
254
+
255
+ service_class = function_info["service"]
256
+ async with service_class() as service:
257
+ func = getattr(service, function_info["function"], None)
258
+ if not func:
259
+ raise AttributeError(
260
+ f"No such function {function_info['function']} in service"
261
+ )
262
+ response = await func(arguments["query"])
263
+ response = f"Context : {response}\n\n# Use this context to respond to the user query : {arguments['query']}"
264
+ return response
src/utils/_openai_client.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import httpx
3
  from fastapi import HTTPException
4
  from aiortc import RTCPeerConnection, RTCSessionDescription
 
5
 
6
  from src.config import logger
7
 
@@ -27,6 +28,7 @@ _OPENAI_TOOLS = [
27
 
28
  class OpenAIClient:
29
  def __init__(self):
 
30
  self.session_url = "https://api.openai.com/v1/realtime/sessions"
31
  self.model = "gpt-4o-mini-realtime-preview-2024-12-17"
32
  self.webrtc_url = f"https://api.openai.com/v1/realtime?model={self.model}"
@@ -40,6 +42,7 @@ class OpenAIClient:
40
 
41
  async def __aenter__(self):
42
  self.system_prompt = await self.prompt_loader("system_prompt.md")
 
43
  return self
44
 
45
  async def __aexit__(self, *args):
@@ -51,6 +54,19 @@ class OpenAIClient:
51
  prompt = file.read()
52
  return prompt
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  async def create_openai_session(self, text_mode_only=False):
55
  headers = {
56
  "Authorization": f'Bearer {os.getenv("OPENAI_API_KEY")}',
 
2
  import httpx
3
  from fastapi import HTTPException
4
  from aiortc import RTCPeerConnection, RTCSessionDescription
5
+ from openai import AsyncOpenAI
6
 
7
  from src.config import logger
8
 
 
28
 
29
  class OpenAIClient:
30
  def __init__(self):
31
+ self.client = None
32
  self.session_url = "https://api.openai.com/v1/realtime/sessions"
33
  self.model = "gpt-4o-mini-realtime-preview-2024-12-17"
34
  self.webrtc_url = f"https://api.openai.com/v1/realtime?model={self.model}"
 
42
 
43
  async def __aenter__(self):
44
  self.system_prompt = await self.prompt_loader("system_prompt.md")
45
+ self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
46
  return self
47
 
48
  async def __aexit__(self, *args):
 
54
  prompt = file.read()
55
  return prompt
56
 
57
+ async def text_generation(self, query: str):
58
+
59
+ completion = await self.client.chat.completions.create(
60
+ model="gpt-4o-mini-2024-07-18",
61
+ messages=[
62
+ {
63
+ "role": "user",
64
+ "content": query,
65
+ },
66
+ ],
67
+ )
68
+ return completion.choices[0].message.content
69
+
70
  async def create_openai_session(self, text_mode_only=False):
71
  headers = {
72
  "Authorization": f'Bearer {os.getenv("OPENAI_API_KEY")}',