ramanjitsingh1368 commited on
Commit
1e0a76d
·
1 Parent(s): be1d374

Improve file and conversation services to use repository pattern

Browse files
src/controllers/_file_controller.py CHANGED
@@ -15,7 +15,7 @@ class FileController:
15
 
16
  async def insert_file(self, file: UploadFile = File(...)):
17
  try:
18
- user_id = "67987a511dd18e6ae8d0671c"
19
  async with self.service() as service:
20
  return await service.insert_file(file=file, user_id=user_id)
21
  except Exception as e:
 
15
 
16
  async def insert_file(self, file: UploadFile = File(...)):
17
  try:
18
+ user_id = "67987b4a7ce47dbbdcbb28f1"
19
  async with self.service() as service:
20
  return await service.insert_file(file=file, user_id=user_id)
21
  except Exception as e:
src/middlewares/_authentication.py CHANGED
@@ -15,6 +15,7 @@ EXEMPT_ROUTES = [
15
  "/api/v1/auth/logout",
16
  "/api/v1/auth/register",
17
  "/api/v1/conversations",
 
18
  ]
19
 
20
  EXEMPT_ROUTE_PATTERNS = [
 
15
  "/api/v1/auth/logout",
16
  "/api/v1/auth/register",
17
  "/api/v1/conversations",
18
+ "/api/v1/files",
19
  ]
20
 
21
  EXEMPT_ROUTE_PATTERNS = [
src/repositories/_base_repository.py CHANGED
@@ -8,13 +8,19 @@ class BaseRepository:
8
  def __init__(self, model: Type[T]):
9
  self.model = model
10
 
11
- async def insert_one(self, data: T):
12
  try:
13
  return await data.insert(link_rule=WriteRules.WRITE)
14
  except Exception as e:
15
  raise ValueError(f"Failed to insert data: {str(e)}")
16
 
17
- async def get_all(self, page: int = 1, page_size: int = None, filter_by: Dict = None, order_by: Dict = None):
 
 
 
 
 
 
18
  query = filter_by or {}
19
  cursor = self.model.find(query, fetch_links=True)
20
 
@@ -27,7 +33,7 @@ class BaseRepository:
27
 
28
  return await cursor.to_list(length=page_size if page_size else None)
29
 
30
- async def get_by_id(self, _id: str):
31
  return await self.model.get(_id, fetch_links=True)
32
 
33
  async def update(self, _id: str, update_data: dict):
@@ -37,6 +43,7 @@ class BaseRepository:
37
 
38
  updated_fields = {**document.model_dump(), **update_data.model_dump()}
39
  return await document.set(updated_fields)
 
40
  async def delete(self, _id: str):
41
 
42
  document = await self.model.get(_id)
@@ -45,4 +52,4 @@ class BaseRepository:
45
  return await document.delete(link_rule=DeleteRules.DELETE)
46
 
47
  async def count_documents(self):
48
- return await self.model.count()
 
8
  def __init__(self, model: Type[T]):
9
  self.model = model
10
 
11
+ async def insert_one(self, data: T) -> T:
12
  try:
13
  return await data.insert(link_rule=WriteRules.WRITE)
14
  except Exception as e:
15
  raise ValueError(f"Failed to insert data: {str(e)}")
16
 
17
+ async def get_all(
18
+ self,
19
+ page: int = 1,
20
+ page_size: int = None,
21
+ filter_by: Dict = None,
22
+ order_by: Dict = None,
23
+ ):
24
  query = filter_by or {}
25
  cursor = self.model.find(query, fetch_links=True)
26
 
 
33
 
34
  return await cursor.to_list(length=page_size if page_size else None)
35
 
36
+ async def get_by_id(self, _id: str) -> T:
37
  return await self.model.get(_id, fetch_links=True)
38
 
39
  async def update(self, _id: str, update_data: dict):
 
43
 
44
  updated_fields = {**document.model_dump(), **update_data.model_dump()}
45
  return await document.set(updated_fields)
46
+
47
  async def delete(self, _id: str):
48
 
49
  document = await self.model.get(_id)
 
52
  return await document.delete(link_rule=DeleteRules.DELETE)
53
 
54
  async def count_documents(self):
55
+ return await self.model.count()
src/repositories/_message_repository.py CHANGED
@@ -1,4 +1,4 @@
1
- from src.models import Message
2
 
3
  from ._base_repository import BaseRepository
4
 
@@ -6,3 +6,19 @@ from ._base_repository import BaseRepository
6
  class MessageRepository(BaseRepository):
7
  def __init__(self):
8
  super().__init__(model=Message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.models import Message, Conversation
2
 
3
  from ._base_repository import BaseRepository
4
 
 
6
  class MessageRepository(BaseRepository):
7
  def __init__(self):
8
  super().__init__(model=Message)
9
+ self.model = Message
10
+
11
+ async def get_messages_by_conversation(
12
+ self, conversation: Conversation
13
+ ) -> list[Message]:
14
+ try:
15
+ messages = await self.model.find_many(
16
+ self.model.conversation.id == conversation.id, fetch_links=True
17
+ ).to_list()
18
+
19
+ if messages:
20
+ return messages
21
+ else:
22
+ raise ValueError("Messages not found")
23
+ except Exception as e:
24
+ raise ValueError(f"Failed to get messages: {str(e)}")
src/services/_conversation_service.py CHANGED
@@ -1,10 +1,9 @@
1
  from fastapi import WebSocket
2
- from bson import ObjectId
3
- from beanie import WriteRules
4
 
5
  from src.config import logger
6
  from src.utils import OpenAIClient
7
  from src.models import Conversation, User, Message
 
8
 
9
  from ._websocket_service import WebSocketService
10
 
@@ -13,6 +12,9 @@ class ConversationService:
13
  def __init__(self):
14
  self.openai_client = OpenAIClient
15
  self.websocket_service = WebSocketService
 
 
 
16
 
17
  async def __aenter__(self):
18
  return self
@@ -21,10 +23,11 @@ class ConversationService:
21
  pass
22
 
23
  async def create_conversation(self, user_id, modality):
24
- user_object = await User.get(ObjectId(user_id))
25
- conversation_object = Conversation(user=user_object, summary="")
26
- conversation = await conversation_object.save(link_rule=WriteRules.WRITE)
27
- print("Conversation ID: ", conversation.id)
 
28
 
29
  text_mode_only = True if modality == "text" else False
30
  async with self.openai_client() as client:
@@ -33,7 +36,7 @@ class ConversationService:
33
  )
34
 
35
  return {
36
- "conversation_id": str(conversation.id),
37
  "session_data": session_data,
38
  }
39
 
@@ -48,11 +51,14 @@ class ConversationService:
48
  await ws_service.handle_conversation(websocket)
49
 
50
  async def create_conversation_summary(self, conversation_id):
51
- conversation_object = await Conversation.get(ObjectId(conversation_id))
52
- messages: list[Message] = await Message.find_many(
53
- Message.conversation.id == conversation_object.id,
54
- fetch_links=True,
55
- ).to_list()
 
 
 
56
 
57
  conversation_history = "\n".join(
58
  [f"{message.role}: {message.content}" for message in messages]
@@ -64,6 +70,7 @@ class ConversationService:
64
  )
65
 
66
  conversation_object.summary = conversation_summary
67
- await conversation_object.save(link_rule=WriteRules.WRITE)
68
-
69
- return {"conversation_summary": conversation_summary}
 
 
1
  from fastapi import WebSocket
 
 
2
 
3
  from src.config import logger
4
  from src.utils import OpenAIClient
5
  from src.models import Conversation, User, Message
6
+ from src.repositories import ConversationRepository, MessageRepository, UserRepository
7
 
8
  from ._websocket_service import WebSocketService
9
 
 
12
  def __init__(self):
13
  self.openai_client = OpenAIClient
14
  self.websocket_service = WebSocketService
15
+ self.conversation_repository = ConversationRepository()
16
+ self.message_repository = MessageRepository()
17
+ self.user_repository = UserRepository()
18
 
19
  async def __aenter__(self):
20
  return self
 
23
  pass
24
 
25
  async def create_conversation(self, user_id, modality):
26
+ user_object = await self.user_repository.get_by_id(user_id)
27
+ conversation_object = await self.conversation_repository.insert_one(
28
+ Conversation(user=user_object, summary="")
29
+ )
30
+ print("Conversation ID: ", conversation_object.id)
31
 
32
  text_mode_only = True if modality == "text" else False
33
  async with self.openai_client() as client:
 
36
  )
37
 
38
  return {
39
+ "conversation_id": str(conversation_object.id),
40
  "session_data": session_data,
41
  }
42
 
 
51
  await ws_service.handle_conversation(websocket)
52
 
53
  async def create_conversation_summary(self, conversation_id):
54
+ conversation_object: Conversation = (
55
+ await self.conversation_repository.get_by_id(conversation_id)
56
+ )
57
+ messages: list[Message] = (
58
+ await self.message_repository.get_messages_by_conversation(
59
+ conversation_object
60
+ )
61
+ )
62
 
63
  conversation_history = "\n".join(
64
  [f"{message.role}: {message.content}" for message in messages]
 
70
  )
71
 
72
  conversation_object.summary = conversation_summary
73
+ updated_converation: Conversation = await self.conversation_repository.update(
74
+ conversation_object.id, conversation_object
75
+ )
76
+ return {"conversation_summary": updated_converation.summary}
src/services/_file_service.py CHANGED
@@ -1,14 +1,18 @@
1
  from fastapi import UploadFile
2
- from beanie import WriteRules
3
 
4
- from src.repositories import FileRepository, VectorStoreRecordIdRepository
5
- from src.models import File, VectorStoreRecordId
6
  from src.utils import PineconeClient
 
 
 
 
 
 
7
 
8
 
9
  class FileService:
10
  def __init__(self):
11
  self.pinecone_client = PineconeClient
 
12
  self.file_repository = FileRepository()
13
  self.vector_store_record_id_repository = VectorStoreRecordIdRepository()
14
 
@@ -20,6 +24,8 @@ class FileService:
20
 
21
  async def insert_file(self, file: UploadFile, user_id: str):
22
  try:
 
 
23
  metadata = [{"file_name": file.filename}]
24
  while content := await file.read():
25
  async with self.pinecone_client() as client:
@@ -27,18 +33,18 @@ class FileService:
27
  text=[content.decode()], metadata=metadata
28
  )
29
 
30
- file_object = File(
31
- user_id=user_id,
32
- file_name=file.filename,
33
- file_type=file.content_type,
 
 
34
  )
35
- await file_object.save(link_rule=WriteRules.WRITE)
36
 
37
  for id in ids:
38
- vector_store_record_id_object = VectorStoreRecordId(
39
- file_id=file_object.id, record_id=id
40
  )
41
- await vector_store_record_id_object.save(link_rule=WriteRules.WRITE)
42
  return file_object
43
 
44
  finally:
 
1
  from fastapi import UploadFile
 
2
 
 
 
3
  from src.utils import PineconeClient
4
+ from src.models import File, VectorStoreRecordId
5
+ from src.repositories import (
6
+ FileRepository,
7
+ VectorStoreRecordIdRepository,
8
+ UserRepository,
9
+ )
10
 
11
 
12
  class FileService:
13
  def __init__(self):
14
  self.pinecone_client = PineconeClient
15
+ self.user_repository = UserRepository()
16
  self.file_repository = FileRepository()
17
  self.vector_store_record_id_repository = VectorStoreRecordIdRepository()
18
 
 
24
 
25
  async def insert_file(self, file: UploadFile, user_id: str):
26
  try:
27
+ user_object = await self.user_repository.get_by_id(user_id)
28
+
29
  metadata = [{"file_name": file.filename}]
30
  while content := await file.read():
31
  async with self.pinecone_client() as client:
 
33
  text=[content.decode()], metadata=metadata
34
  )
35
 
36
+ file_object = await self.file_repository.insert_one(
37
+ File(
38
+ user=user_object,
39
+ file_name=file.filename,
40
+ file_type=file.content_type,
41
+ )
42
  )
 
43
 
44
  for id in ids:
45
+ await self.vector_store_record_id_repository.insert_one(
46
+ VectorStoreRecordId(file=file_object, record_id=id)
47
  )
 
48
  return file_object
49
 
50
  finally:
src/services/_websocket_service.py CHANGED
@@ -1,13 +1,12 @@
1
  import json
2
  from fastapi import WebSocket
3
- from bson import ObjectId
4
- from beanie import WriteRules
5
 
6
- from ._file_service import FileService
7
  from src.utils import OpenAIClient
8
- from src.repositories import ConversationRepository, MessageRepository
9
- from src.models import Conversation, Message
10
  from src.config import logger
 
 
 
11
 
12
 
13
  class WebSocketService:
@@ -203,23 +202,29 @@ class WebSocketService:
203
 
204
  async def handle_user_message(self, message_content, conversation_id):
205
  logger.info(f"User Query: {message_content}")
206
- conversation_object = await Conversation.get(ObjectId(conversation_id))
207
- message_object = Message(
208
- conversation=conversation_object,
209
- role="user",
210
- content=message_content,
 
 
 
 
211
  )
212
- return await message_object.save(link_rule=WriteRules.WRITE)
213
 
214
  async def handle_ai_message(self, message_content, conversation_id):
215
  logger.info(f"AI Response: {message_content}")
216
- conversation_object = await Conversation.get(ObjectId(conversation_id))
217
- message_object = Message(
218
- conversation=conversation_object,
219
- role="assistant",
220
- content=message_content,
 
 
 
 
221
  )
222
- return await message_object.save(link_rule=WriteRules.WRITE)
223
 
224
  async def handle_ai_function_call(self, data):
225
  tool_name = data["name"]
 
1
  import json
2
  from fastapi import WebSocket
 
 
3
 
 
4
  from src.utils import OpenAIClient
5
+ from src.models import Message
 
6
  from src.config import logger
7
+ from src.repositories import ConversationRepository, MessageRepository
8
+
9
+ from ._file_service import FileService
10
 
11
 
12
  class WebSocketService:
 
202
 
203
  async def handle_user_message(self, message_content, conversation_id):
204
  logger.info(f"User Query: {message_content}")
205
+ conversation_object = await self.conversation_repository.get_by_id(
206
+ conversation_id
207
+ )
208
+ await self.message_repository.insert_one(
209
+ Message(
210
+ conversation=conversation_object,
211
+ role="user",
212
+ content=message_content,
213
+ )
214
  )
 
215
 
216
  async def handle_ai_message(self, message_content, conversation_id):
217
  logger.info(f"AI Response: {message_content}")
218
+ conversation_object = await self.conversation_repository.get_by_id(
219
+ conversation_id
220
+ )
221
+ await self.message_repository.insert_one(
222
+ Message(
223
+ conversation=conversation_object,
224
+ role="assistant",
225
+ content=message_content,
226
+ )
227
  )
 
228
 
229
  async def handle_ai_function_call(self, data):
230
  tool_name = data["name"]