Spaces:
Runtime error
Runtime error
Commit ·
1e0a76d
1
Parent(s): be1d374
Improve file and conversation services to use repository pattern
Browse files- src/controllers/_file_controller.py +1 -1
- src/middlewares/_authentication.py +1 -0
- src/repositories/_base_repository.py +11 -4
- src/repositories/_message_repository.py +17 -1
- src/services/_conversation_service.py +22 -15
- src/services/_file_service.py +17 -11
- src/services/_websocket_service.py +22 -17
src/controllers/_file_controller.py
CHANGED
|
@@ -15,7 +15,7 @@ class FileController:
|
|
| 15 |
|
| 16 |
async def insert_file(self, file: UploadFile = File(...)):
|
| 17 |
try:
|
| 18 |
-
user_id = "
|
| 19 |
async with self.service() as service:
|
| 20 |
return await service.insert_file(file=file, user_id=user_id)
|
| 21 |
except Exception as e:
|
|
|
|
| 15 |
|
| 16 |
async def insert_file(self, file: UploadFile = File(...)):
|
| 17 |
try:
|
| 18 |
+
user_id = "67987b4a7ce47dbbdcbb28f1"
|
| 19 |
async with self.service() as service:
|
| 20 |
return await service.insert_file(file=file, user_id=user_id)
|
| 21 |
except Exception as e:
|
src/middlewares/_authentication.py
CHANGED
|
@@ -15,6 +15,7 @@ EXEMPT_ROUTES = [
|
|
| 15 |
"/api/v1/auth/logout",
|
| 16 |
"/api/v1/auth/register",
|
| 17 |
"/api/v1/conversations",
|
|
|
|
| 18 |
]
|
| 19 |
|
| 20 |
EXEMPT_ROUTE_PATTERNS = [
|
|
|
|
| 15 |
"/api/v1/auth/logout",
|
| 16 |
"/api/v1/auth/register",
|
| 17 |
"/api/v1/conversations",
|
| 18 |
+
"/api/v1/files",
|
| 19 |
]
|
| 20 |
|
| 21 |
EXEMPT_ROUTE_PATTERNS = [
|
src/repositories/_base_repository.py
CHANGED
|
@@ -8,13 +8,19 @@ class BaseRepository:
|
|
| 8 |
def __init__(self, model: Type[T]):
|
| 9 |
self.model = model
|
| 10 |
|
| 11 |
-
async def insert_one(self, data: T):
|
| 12 |
try:
|
| 13 |
return await data.insert(link_rule=WriteRules.WRITE)
|
| 14 |
except Exception as e:
|
| 15 |
raise ValueError(f"Failed to insert data: {str(e)}")
|
| 16 |
|
| 17 |
-
async def get_all(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
query = filter_by or {}
|
| 19 |
cursor = self.model.find(query, fetch_links=True)
|
| 20 |
|
|
@@ -27,7 +33,7 @@ class BaseRepository:
|
|
| 27 |
|
| 28 |
return await cursor.to_list(length=page_size if page_size else None)
|
| 29 |
|
| 30 |
-
async def get_by_id(self, _id: str):
|
| 31 |
return await self.model.get(_id, fetch_links=True)
|
| 32 |
|
| 33 |
async def update(self, _id: str, update_data: dict):
|
|
@@ -37,6 +43,7 @@ class BaseRepository:
|
|
| 37 |
|
| 38 |
updated_fields = {**document.model_dump(), **update_data.model_dump()}
|
| 39 |
return await document.set(updated_fields)
|
|
|
|
| 40 |
async def delete(self, _id: str):
|
| 41 |
|
| 42 |
document = await self.model.get(_id)
|
|
@@ -45,4 +52,4 @@ class BaseRepository:
|
|
| 45 |
return await document.delete(link_rule=DeleteRules.DELETE)
|
| 46 |
|
| 47 |
async def count_documents(self):
|
| 48 |
-
return await self.model.count()
|
|
|
|
| 8 |
def __init__(self, model: Type[T]):
|
| 9 |
self.model = model
|
| 10 |
|
| 11 |
+
async def insert_one(self, data: T) -> T:
|
| 12 |
try:
|
| 13 |
return await data.insert(link_rule=WriteRules.WRITE)
|
| 14 |
except Exception as e:
|
| 15 |
raise ValueError(f"Failed to insert data: {str(e)}")
|
| 16 |
|
| 17 |
+
async def get_all(
|
| 18 |
+
self,
|
| 19 |
+
page: int = 1,
|
| 20 |
+
page_size: int = None,
|
| 21 |
+
filter_by: Dict = None,
|
| 22 |
+
order_by: Dict = None,
|
| 23 |
+
):
|
| 24 |
query = filter_by or {}
|
| 25 |
cursor = self.model.find(query, fetch_links=True)
|
| 26 |
|
|
|
|
| 33 |
|
| 34 |
return await cursor.to_list(length=page_size if page_size else None)
|
| 35 |
|
| 36 |
+
async def get_by_id(self, _id: str) -> T:
|
| 37 |
return await self.model.get(_id, fetch_links=True)
|
| 38 |
|
| 39 |
async def update(self, _id: str, update_data: dict):
|
|
|
|
| 43 |
|
| 44 |
updated_fields = {**document.model_dump(), **update_data.model_dump()}
|
| 45 |
return await document.set(updated_fields)
|
| 46 |
+
|
| 47 |
async def delete(self, _id: str):
|
| 48 |
|
| 49 |
document = await self.model.get(_id)
|
|
|
|
| 52 |
return await document.delete(link_rule=DeleteRules.DELETE)
|
| 53 |
|
| 54 |
async def count_documents(self):
|
| 55 |
+
return await self.model.count()
|
src/repositories/_message_repository.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from src.models import Message
|
| 2 |
|
| 3 |
from ._base_repository import BaseRepository
|
| 4 |
|
|
@@ -6,3 +6,19 @@ from ._base_repository import BaseRepository
|
|
| 6 |
class MessageRepository(BaseRepository):
|
| 7 |
def __init__(self):
|
| 8 |
super().__init__(model=Message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.models import Message, Conversation
|
| 2 |
|
| 3 |
from ._base_repository import BaseRepository
|
| 4 |
|
|
|
|
| 6 |
class MessageRepository(BaseRepository):
|
| 7 |
def __init__(self):
|
| 8 |
super().__init__(model=Message)
|
| 9 |
+
self.model = Message
|
| 10 |
+
|
| 11 |
+
async def get_messages_by_conversation(
|
| 12 |
+
self, conversation: Conversation
|
| 13 |
+
) -> list[Message]:
|
| 14 |
+
try:
|
| 15 |
+
messages = await self.model.find_many(
|
| 16 |
+
self.model.conversation.id == conversation.id, fetch_links=True
|
| 17 |
+
).to_list()
|
| 18 |
+
|
| 19 |
+
if messages:
|
| 20 |
+
return messages
|
| 21 |
+
else:
|
| 22 |
+
raise ValueError("Messages not found")
|
| 23 |
+
except Exception as e:
|
| 24 |
+
raise ValueError(f"Failed to get messages: {str(e)}")
|
src/services/_conversation_service.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
| 1 |
from fastapi import WebSocket
|
| 2 |
-
from bson import ObjectId
|
| 3 |
-
from beanie import WriteRules
|
| 4 |
|
| 5 |
from src.config import logger
|
| 6 |
from src.utils import OpenAIClient
|
| 7 |
from src.models import Conversation, User, Message
|
|
|
|
| 8 |
|
| 9 |
from ._websocket_service import WebSocketService
|
| 10 |
|
|
@@ -13,6 +12,9 @@ class ConversationService:
|
|
| 13 |
def __init__(self):
|
| 14 |
self.openai_client = OpenAIClient
|
| 15 |
self.websocket_service = WebSocketService
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
async def __aenter__(self):
|
| 18 |
return self
|
|
@@ -21,10 +23,11 @@ class ConversationService:
|
|
| 21 |
pass
|
| 22 |
|
| 23 |
async def create_conversation(self, user_id, modality):
|
| 24 |
-
user_object = await
|
| 25 |
-
conversation_object =
|
| 26 |
-
|
| 27 |
-
|
|
|
|
| 28 |
|
| 29 |
text_mode_only = True if modality == "text" else False
|
| 30 |
async with self.openai_client() as client:
|
|
@@ -33,7 +36,7 @@ class ConversationService:
|
|
| 33 |
)
|
| 34 |
|
| 35 |
return {
|
| 36 |
-
"conversation_id": str(
|
| 37 |
"session_data": session_data,
|
| 38 |
}
|
| 39 |
|
|
@@ -48,11 +51,14 @@ class ConversationService:
|
|
| 48 |
await ws_service.handle_conversation(websocket)
|
| 49 |
|
| 50 |
async def create_conversation_summary(self, conversation_id):
|
| 51 |
-
conversation_object =
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
conversation_history = "\n".join(
|
| 58 |
[f"{message.role}: {message.content}" for message in messages]
|
|
@@ -64,6 +70,7 @@ class ConversationService:
|
|
| 64 |
)
|
| 65 |
|
| 66 |
conversation_object.summary = conversation_summary
|
| 67 |
-
await
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
| 1 |
from fastapi import WebSocket
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from src.config import logger
|
| 4 |
from src.utils import OpenAIClient
|
| 5 |
from src.models import Conversation, User, Message
|
| 6 |
+
from src.repositories import ConversationRepository, MessageRepository, UserRepository
|
| 7 |
|
| 8 |
from ._websocket_service import WebSocketService
|
| 9 |
|
|
|
|
| 12 |
def __init__(self):
|
| 13 |
self.openai_client = OpenAIClient
|
| 14 |
self.websocket_service = WebSocketService
|
| 15 |
+
self.conversation_repository = ConversationRepository()
|
| 16 |
+
self.message_repository = MessageRepository()
|
| 17 |
+
self.user_repository = UserRepository()
|
| 18 |
|
| 19 |
async def __aenter__(self):
|
| 20 |
return self
|
|
|
|
| 23 |
pass
|
| 24 |
|
| 25 |
async def create_conversation(self, user_id, modality):
|
| 26 |
+
user_object = await self.user_repository.get_by_id(user_id)
|
| 27 |
+
conversation_object = await self.conversation_repository.insert_one(
|
| 28 |
+
Conversation(user=user_object, summary="")
|
| 29 |
+
)
|
| 30 |
+
print("Conversation ID: ", conversation_object.id)
|
| 31 |
|
| 32 |
text_mode_only = True if modality == "text" else False
|
| 33 |
async with self.openai_client() as client:
|
|
|
|
| 36 |
)
|
| 37 |
|
| 38 |
return {
|
| 39 |
+
"conversation_id": str(conversation_object.id),
|
| 40 |
"session_data": session_data,
|
| 41 |
}
|
| 42 |
|
|
|
|
| 51 |
await ws_service.handle_conversation(websocket)
|
| 52 |
|
| 53 |
async def create_conversation_summary(self, conversation_id):
|
| 54 |
+
conversation_object: Conversation = (
|
| 55 |
+
await self.conversation_repository.get_by_id(conversation_id)
|
| 56 |
+
)
|
| 57 |
+
messages: list[Message] = (
|
| 58 |
+
await self.message_repository.get_messages_by_conversation(
|
| 59 |
+
conversation_object
|
| 60 |
+
)
|
| 61 |
+
)
|
| 62 |
|
| 63 |
conversation_history = "\n".join(
|
| 64 |
[f"{message.role}: {message.content}" for message in messages]
|
|
|
|
| 70 |
)
|
| 71 |
|
| 72 |
conversation_object.summary = conversation_summary
|
| 73 |
+
updated_converation: Conversation = await self.conversation_repository.update(
|
| 74 |
+
conversation_object.id, conversation_object
|
| 75 |
+
)
|
| 76 |
+
return {"conversation_summary": updated_converation.summary}
|
src/services/_file_service.py
CHANGED
|
@@ -1,14 +1,18 @@
|
|
| 1 |
from fastapi import UploadFile
|
| 2 |
-
from beanie import WriteRules
|
| 3 |
|
| 4 |
-
from src.repositories import FileRepository, VectorStoreRecordIdRepository
|
| 5 |
-
from src.models import File, VectorStoreRecordId
|
| 6 |
from src.utils import PineconeClient
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
class FileService:
|
| 10 |
def __init__(self):
|
| 11 |
self.pinecone_client = PineconeClient
|
|
|
|
| 12 |
self.file_repository = FileRepository()
|
| 13 |
self.vector_store_record_id_repository = VectorStoreRecordIdRepository()
|
| 14 |
|
|
@@ -20,6 +24,8 @@ class FileService:
|
|
| 20 |
|
| 21 |
async def insert_file(self, file: UploadFile, user_id: str):
|
| 22 |
try:
|
|
|
|
|
|
|
| 23 |
metadata = [{"file_name": file.filename}]
|
| 24 |
while content := await file.read():
|
| 25 |
async with self.pinecone_client() as client:
|
|
@@ -27,18 +33,18 @@ class FileService:
|
|
| 27 |
text=[content.decode()], metadata=metadata
|
| 28 |
)
|
| 29 |
|
| 30 |
-
file_object =
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
)
|
| 35 |
-
await file_object.save(link_rule=WriteRules.WRITE)
|
| 36 |
|
| 37 |
for id in ids:
|
| 38 |
-
|
| 39 |
-
|
| 40 |
)
|
| 41 |
-
await vector_store_record_id_object.save(link_rule=WriteRules.WRITE)
|
| 42 |
return file_object
|
| 43 |
|
| 44 |
finally:
|
|
|
|
| 1 |
from fastapi import UploadFile
|
|
|
|
| 2 |
|
|
|
|
|
|
|
| 3 |
from src.utils import PineconeClient
|
| 4 |
+
from src.models import File, VectorStoreRecordId
|
| 5 |
+
from src.repositories import (
|
| 6 |
+
FileRepository,
|
| 7 |
+
VectorStoreRecordIdRepository,
|
| 8 |
+
UserRepository,
|
| 9 |
+
)
|
| 10 |
|
| 11 |
|
| 12 |
class FileService:
|
| 13 |
def __init__(self):
|
| 14 |
self.pinecone_client = PineconeClient
|
| 15 |
+
self.user_repository = UserRepository()
|
| 16 |
self.file_repository = FileRepository()
|
| 17 |
self.vector_store_record_id_repository = VectorStoreRecordIdRepository()
|
| 18 |
|
|
|
|
| 24 |
|
| 25 |
async def insert_file(self, file: UploadFile, user_id: str):
|
| 26 |
try:
|
| 27 |
+
user_object = await self.user_repository.get_by_id(user_id)
|
| 28 |
+
|
| 29 |
metadata = [{"file_name": file.filename}]
|
| 30 |
while content := await file.read():
|
| 31 |
async with self.pinecone_client() as client:
|
|
|
|
| 33 |
text=[content.decode()], metadata=metadata
|
| 34 |
)
|
| 35 |
|
| 36 |
+
file_object = await self.file_repository.insert_one(
|
| 37 |
+
File(
|
| 38 |
+
user=user_object,
|
| 39 |
+
file_name=file.filename,
|
| 40 |
+
file_type=file.content_type,
|
| 41 |
+
)
|
| 42 |
)
|
|
|
|
| 43 |
|
| 44 |
for id in ids:
|
| 45 |
+
await self.vector_store_record_id_repository.insert_one(
|
| 46 |
+
VectorStoreRecordId(file=file_object, record_id=id)
|
| 47 |
)
|
|
|
|
| 48 |
return file_object
|
| 49 |
|
| 50 |
finally:
|
src/services/_websocket_service.py
CHANGED
|
@@ -1,13 +1,12 @@
|
|
| 1 |
import json
|
| 2 |
from fastapi import WebSocket
|
| 3 |
-
from bson import ObjectId
|
| 4 |
-
from beanie import WriteRules
|
| 5 |
|
| 6 |
-
from ._file_service import FileService
|
| 7 |
from src.utils import OpenAIClient
|
| 8 |
-
from src.
|
| 9 |
-
from src.models import Conversation, Message
|
| 10 |
from src.config import logger
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
class WebSocketService:
|
|
@@ -203,23 +202,29 @@ class WebSocketService:
|
|
| 203 |
|
| 204 |
async def handle_user_message(self, message_content, conversation_id):
|
| 205 |
logger.info(f"User Query: {message_content}")
|
| 206 |
-
conversation_object = await
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
)
|
| 212 |
-
return await message_object.save(link_rule=WriteRules.WRITE)
|
| 213 |
|
| 214 |
async def handle_ai_message(self, message_content, conversation_id):
|
| 215 |
logger.info(f"AI Response: {message_content}")
|
| 216 |
-
conversation_object = await
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
)
|
| 222 |
-
return await message_object.save(link_rule=WriteRules.WRITE)
|
| 223 |
|
| 224 |
async def handle_ai_function_call(self, data):
|
| 225 |
tool_name = data["name"]
|
|
|
|
| 1 |
import json
|
| 2 |
from fastapi import WebSocket
|
|
|
|
|
|
|
| 3 |
|
|
|
|
| 4 |
from src.utils import OpenAIClient
|
| 5 |
+
from src.models import Message
|
|
|
|
| 6 |
from src.config import logger
|
| 7 |
+
from src.repositories import ConversationRepository, MessageRepository
|
| 8 |
+
|
| 9 |
+
from ._file_service import FileService
|
| 10 |
|
| 11 |
|
| 12 |
class WebSocketService:
|
|
|
|
| 202 |
|
| 203 |
async def handle_user_message(self, message_content, conversation_id):
|
| 204 |
logger.info(f"User Query: {message_content}")
|
| 205 |
+
conversation_object = await self.conversation_repository.get_by_id(
|
| 206 |
+
conversation_id
|
| 207 |
+
)
|
| 208 |
+
await self.message_repository.insert_one(
|
| 209 |
+
Message(
|
| 210 |
+
conversation=conversation_object,
|
| 211 |
+
role="user",
|
| 212 |
+
content=message_content,
|
| 213 |
+
)
|
| 214 |
)
|
|
|
|
| 215 |
|
| 216 |
async def handle_ai_message(self, message_content, conversation_id):
|
| 217 |
logger.info(f"AI Response: {message_content}")
|
| 218 |
+
conversation_object = await self.conversation_repository.get_by_id(
|
| 219 |
+
conversation_id
|
| 220 |
+
)
|
| 221 |
+
await self.message_repository.insert_one(
|
| 222 |
+
Message(
|
| 223 |
+
conversation=conversation_object,
|
| 224 |
+
role="assistant",
|
| 225 |
+
content=message_content,
|
| 226 |
+
)
|
| 227 |
)
|
|
|
|
| 228 |
|
| 229 |
async def handle_ai_function_call(self, data):
|
| 230 |
tool_name = data["name"]
|