feat: enhanced repository usages
Browse files- app/api/chat_api.py +3 -0
- app/config/db.py +0 -3
- app/config/{secret.py → security_config.py} +10 -6
- app/core/db_client.py +29 -25
- app/core/initial_setup/setup.py +40 -22
- app/db/client.py +3 -1
- app/db/embedded.py +1 -3
- app/db/factory.py +1 -3
- app/db/mongo.py +2 -1
- app/mapper/base_mapper.py +26 -0
- app/mapper/chat_mapper.py +47 -0
- app/model/chat_model.py +28 -51
- app/repository/chat_repository.py +97 -31
- app/schema/chat_schema.py +6 -18
- app/security/auth_service.py +19 -8
- app/service/chat_service.py +12 -16
- main.py +2 -2
app/api/chat_api.py
CHANGED
|
@@ -22,14 +22,17 @@ router = APIRouter(prefix="/v1", tags=["chat"])
|
|
| 22 |
service = ChatService()
|
| 23 |
auth_service = AuthService()
|
| 24 |
|
|
|
|
| 25 |
class VersionResponse(BaseModel):
|
| 26 |
version: str = "0.0.1"
|
| 27 |
|
|
|
|
| 28 |
# version api from pyproject.toml
|
| 29 |
@router.get("/version", response_model=VersionResponse)
|
| 30 |
async def get_version():
|
| 31 |
return VersionResponse()
|
| 32 |
|
|
|
|
| 33 |
################
|
| 34 |
# chat completion api list
|
| 35 |
################
|
|
|
|
| 22 |
service = ChatService()
|
| 23 |
auth_service = AuthService()
|
| 24 |
|
| 25 |
+
|
| 26 |
class VersionResponse(BaseModel):
|
| 27 |
version: str = "0.0.1"
|
| 28 |
|
| 29 |
+
|
| 30 |
# version api from pyproject.toml
|
| 31 |
@router.get("/version", response_model=VersionResponse)
|
| 32 |
async def get_version():
|
| 33 |
return VersionResponse()
|
| 34 |
|
| 35 |
+
|
| 36 |
################
|
| 37 |
# chat completion api list
|
| 38 |
################
|
app/config/db.py
CHANGED
|
@@ -10,7 +10,6 @@ class DBConfig(BaseSettings):
|
|
| 10 |
extra="ignore",
|
| 11 |
)
|
| 12 |
|
| 13 |
-
|
| 14 |
DATABASE_TYPE: Literal["mongodb", "embedded"] = "embedded"
|
| 15 |
DATABASE_NAME: str = "openai_chatbot_api"
|
| 16 |
|
|
@@ -19,8 +18,6 @@ class DBConfig(BaseSettings):
|
|
| 19 |
MONGO_HOST: str = "localhost"
|
| 20 |
MONGO_PORT: int = 27017
|
| 21 |
MONGO_URI: str = f"mongodb://{MONGO_USER}:{MONGO_PASSWORD}@{MONGO_HOST}:{MONGO_PORT}/{DATABASE_NAME}"
|
| 22 |
-
|
| 23 |
-
|
| 24 |
|
| 25 |
def get_mongo_uri(self) -> str:
|
| 26 |
return self.MONGO_URI
|
|
|
|
| 10 |
extra="ignore",
|
| 11 |
)
|
| 12 |
|
|
|
|
| 13 |
DATABASE_TYPE: Literal["mongodb", "embedded"] = "embedded"
|
| 14 |
DATABASE_NAME: str = "openai_chatbot_api"
|
| 15 |
|
|
|
|
| 18 |
MONGO_HOST: str = "localhost"
|
| 19 |
MONGO_PORT: int = 27017
|
| 20 |
MONGO_URI: str = f"mongodb://{MONGO_USER}:{MONGO_PASSWORD}@{MONGO_HOST}:{MONGO_PORT}/{DATABASE_NAME}"
|
|
|
|
|
|
|
| 21 |
|
| 22 |
def get_mongo_uri(self) -> str:
|
| 23 |
return self.MONGO_URI
|
app/config/{secret.py → security_config.py}
RENAMED
|
@@ -1,17 +1,21 @@
|
|
| 1 |
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
-
class
|
| 5 |
-
"""
|
| 6 |
|
| 7 |
model_config = SettingsConfigDict(
|
| 8 |
-
env_prefix="
|
| 9 |
env_file=".env",
|
| 10 |
env_file_encoding="utf-8",
|
| 11 |
extra="ignore",
|
| 12 |
)
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
| 1 |
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 2 |
+
from functools import lru_cache
|
| 3 |
|
| 4 |
|
| 5 |
+
class SecurityConfig(BaseSettings):
|
| 6 |
+
"""Security configuration to be set in env variables"""
|
| 7 |
|
| 8 |
model_config = SettingsConfigDict(
|
| 9 |
+
env_prefix="SECURITY_",
|
| 10 |
env_file=".env",
|
| 11 |
env_file_encoding="utf-8",
|
| 12 |
extra="ignore",
|
| 13 |
)
|
| 14 |
|
| 15 |
+
SECRET_KEY: str = "your-secret-key-here"
|
| 16 |
+
ENABLED: bool = True
|
| 17 |
+
DEFAULT_USERNAME: str = "admin"
|
| 18 |
|
| 19 |
+
@lru_cache()
|
| 20 |
+
def get_security_config() -> SecurityConfig:
|
| 21 |
+
return SecurityConfig()
|
app/core/db_client.py
CHANGED
|
@@ -1,110 +1,114 @@
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
-
from typing import Optional,
|
| 3 |
from motor.motor_asyncio import AsyncIOMotorClient
|
| 4 |
from unittest.mock import AsyncMock
|
| 5 |
|
|
|
|
| 6 |
class DatabaseClient(ABC):
|
| 7 |
"""Abstract base class for database clients"""
|
| 8 |
-
|
| 9 |
@abstractmethod
|
| 10 |
async def connect(self) -> None:
|
| 11 |
"""Connect to the database"""
|
| 12 |
pass
|
| 13 |
-
|
| 14 |
@abstractmethod
|
| 15 |
async def disconnect(self) -> None:
|
| 16 |
"""Disconnect from the database"""
|
| 17 |
pass
|
| 18 |
-
|
| 19 |
@abstractmethod
|
| 20 |
async def get_database(self) -> Any:
|
| 21 |
"""Get database instance"""
|
| 22 |
pass
|
| 23 |
|
|
|
|
| 24 |
class MongoClient(DatabaseClient):
|
| 25 |
"""Real MongoDB client implementation"""
|
| 26 |
-
|
| 27 |
def __init__(self, connection_string: str):
|
| 28 |
self.connection_string = connection_string
|
| 29 |
self._client: Optional[AsyncIOMotorClient] = None
|
| 30 |
-
|
| 31 |
async def connect(self) -> None:
|
| 32 |
self._client = AsyncIOMotorClient(self.connection_string)
|
| 33 |
-
|
| 34 |
async def disconnect(self) -> None:
|
| 35 |
if self._client:
|
| 36 |
self._client.close()
|
| 37 |
-
|
| 38 |
async def get_database(self) -> AsyncIOMotorClient:
|
| 39 |
if not self._client:
|
| 40 |
await self.connect()
|
| 41 |
return self._client
|
| 42 |
|
|
|
|
| 43 |
class MockMongoClient(DatabaseClient):
|
| 44 |
"""Mock MongoDB client for testing"""
|
| 45 |
-
|
| 46 |
def __init__(self):
|
| 47 |
self._client = AsyncMock()
|
| 48 |
-
|
| 49 |
async def connect(self) -> None:
|
| 50 |
pass
|
| 51 |
-
|
| 52 |
async def disconnect(self) -> None:
|
| 53 |
pass
|
| 54 |
-
|
| 55 |
async def get_database(self) -> AsyncMock:
|
| 56 |
return self._client
|
| 57 |
|
|
|
|
| 58 |
class DatabaseClientFactory:
|
| 59 |
"""Factory for creating database clients"""
|
| 60 |
-
|
| 61 |
-
_instance: Optional[
|
| 62 |
_client: Optional[DatabaseClient] = None
|
| 63 |
-
|
| 64 |
def __new__(cls):
|
| 65 |
if cls._instance is None:
|
| 66 |
cls._instance = super().__new__(cls)
|
| 67 |
return cls._instance
|
| 68 |
-
|
| 69 |
@classmethod
|
| 70 |
def create_client(cls, db_type: str, connection_string: Optional[str] = None) -> DatabaseClient:
|
| 71 |
"""
|
| 72 |
Create a database client based on the database type
|
| 73 |
-
|
| 74 |
Args:
|
| 75 |
db_type: Type of database ('mongodb' or 'mock')
|
| 76 |
connection_string: Connection string for the database
|
| 77 |
-
|
| 78 |
Returns:
|
| 79 |
DatabaseClient: Instance of the appropriate database client
|
| 80 |
"""
|
| 81 |
if cls._client is None:
|
| 82 |
-
if db_type.lower() ==
|
| 83 |
if not connection_string:
|
| 84 |
raise ValueError("Connection string is required for MongoDB")
|
| 85 |
cls._client = MongoClient(connection_string)
|
| 86 |
-
elif db_type.lower() ==
|
| 87 |
cls._client = MockMongoClient()
|
| 88 |
else:
|
| 89 |
raise ValueError(f"Unsupported database type: {db_type}")
|
| 90 |
-
|
| 91 |
return cls._client
|
| 92 |
-
|
| 93 |
@classmethod
|
| 94 |
async def get_client(cls) -> DatabaseClient:
|
| 95 |
"""
|
| 96 |
Get the current database client instance
|
| 97 |
-
|
| 98 |
Returns:
|
| 99 |
DatabaseClient: Current database client instance
|
| 100 |
"""
|
| 101 |
if cls._client is None:
|
| 102 |
raise RuntimeError("Database client not initialized")
|
| 103 |
return cls._client
|
| 104 |
-
|
| 105 |
@classmethod
|
| 106 |
async def reset_client(cls) -> None:
|
| 107 |
"""Reset the current database client instance"""
|
| 108 |
if cls._client:
|
| 109 |
await cls._client.disconnect()
|
| 110 |
-
cls._client = None
|
|
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Optional, Any
|
| 3 |
from motor.motor_asyncio import AsyncIOMotorClient
|
| 4 |
from unittest.mock import AsyncMock
|
| 5 |
|
| 6 |
+
|
| 7 |
class DatabaseClient(ABC):
|
| 8 |
"""Abstract base class for database clients"""
|
| 9 |
+
|
| 10 |
@abstractmethod
|
| 11 |
async def connect(self) -> None:
|
| 12 |
"""Connect to the database"""
|
| 13 |
pass
|
| 14 |
+
|
| 15 |
@abstractmethod
|
| 16 |
async def disconnect(self) -> None:
|
| 17 |
"""Disconnect from the database"""
|
| 18 |
pass
|
| 19 |
+
|
| 20 |
@abstractmethod
|
| 21 |
async def get_database(self) -> Any:
|
| 22 |
"""Get database instance"""
|
| 23 |
pass
|
| 24 |
|
| 25 |
+
|
| 26 |
class MongoClient(DatabaseClient):
|
| 27 |
"""Real MongoDB client implementation"""
|
| 28 |
+
|
| 29 |
def __init__(self, connection_string: str):
|
| 30 |
self.connection_string = connection_string
|
| 31 |
self._client: Optional[AsyncIOMotorClient] = None
|
| 32 |
+
|
| 33 |
async def connect(self) -> None:
|
| 34 |
self._client = AsyncIOMotorClient(self.connection_string)
|
| 35 |
+
|
| 36 |
async def disconnect(self) -> None:
|
| 37 |
if self._client:
|
| 38 |
self._client.close()
|
| 39 |
+
|
| 40 |
async def get_database(self) -> AsyncIOMotorClient:
|
| 41 |
if not self._client:
|
| 42 |
await self.connect()
|
| 43 |
return self._client
|
| 44 |
|
| 45 |
+
|
| 46 |
class MockMongoClient(DatabaseClient):
|
| 47 |
"""Mock MongoDB client for testing"""
|
| 48 |
+
|
| 49 |
def __init__(self):
|
| 50 |
self._client = AsyncMock()
|
| 51 |
+
|
| 52 |
async def connect(self) -> None:
|
| 53 |
pass
|
| 54 |
+
|
| 55 |
async def disconnect(self) -> None:
|
| 56 |
pass
|
| 57 |
+
|
| 58 |
async def get_database(self) -> AsyncMock:
|
| 59 |
return self._client
|
| 60 |
|
| 61 |
+
|
| 62 |
class DatabaseClientFactory:
|
| 63 |
"""Factory for creating database clients"""
|
| 64 |
+
|
| 65 |
+
_instance: Optional["DatabaseClientFactory"] = None
|
| 66 |
_client: Optional[DatabaseClient] = None
|
| 67 |
+
|
| 68 |
def __new__(cls):
|
| 69 |
if cls._instance is None:
|
| 70 |
cls._instance = super().__new__(cls)
|
| 71 |
return cls._instance
|
| 72 |
+
|
| 73 |
@classmethod
|
| 74 |
def create_client(cls, db_type: str, connection_string: Optional[str] = None) -> DatabaseClient:
|
| 75 |
"""
|
| 76 |
Create a database client based on the database type
|
| 77 |
+
|
| 78 |
Args:
|
| 79 |
db_type: Type of database ('mongodb' or 'mock')
|
| 80 |
connection_string: Connection string for the database
|
| 81 |
+
|
| 82 |
Returns:
|
| 83 |
DatabaseClient: Instance of the appropriate database client
|
| 84 |
"""
|
| 85 |
if cls._client is None:
|
| 86 |
+
if db_type.lower() == "mongodb":
|
| 87 |
if not connection_string:
|
| 88 |
raise ValueError("Connection string is required for MongoDB")
|
| 89 |
cls._client = MongoClient(connection_string)
|
| 90 |
+
elif db_type.lower() == "mock":
|
| 91 |
cls._client = MockMongoClient()
|
| 92 |
else:
|
| 93 |
raise ValueError(f"Unsupported database type: {db_type}")
|
| 94 |
+
|
| 95 |
return cls._client
|
| 96 |
+
|
| 97 |
@classmethod
|
| 98 |
async def get_client(cls) -> DatabaseClient:
|
| 99 |
"""
|
| 100 |
Get the current database client instance
|
| 101 |
+
|
| 102 |
Returns:
|
| 103 |
DatabaseClient: Current database client instance
|
| 104 |
"""
|
| 105 |
if cls._client is None:
|
| 106 |
raise RuntimeError("Database client not initialized")
|
| 107 |
return cls._client
|
| 108 |
+
|
| 109 |
@classmethod
|
| 110 |
async def reset_client(cls) -> None:
|
| 111 |
"""Reset the current database client instance"""
|
| 112 |
if cls._client:
|
| 113 |
await cls._client.disconnect()
|
| 114 |
+
cls._client = None
|
app/core/initial_setup/setup.py
CHANGED
|
@@ -6,20 +6,21 @@ from app.model.chat_model import ChatCompletion
|
|
| 6 |
from app.repository.chat_repository import ChatRepository
|
| 7 |
from app.config.db import db_config
|
| 8 |
|
|
|
|
| 9 |
class InitialSetup:
|
| 10 |
"""Initial setup manager for the application when database type is embedded"""
|
| 11 |
-
|
| 12 |
def __init__(self):
|
| 13 |
self._chat_repository: Optional[ChatRepository] = None
|
| 14 |
self.data_dir = os.path.join(os.path.dirname(__file__), "data")
|
| 15 |
-
|
| 16 |
@property
|
| 17 |
def chat_repository(self) -> ChatRepository:
|
| 18 |
"""Lazy loading of ChatRepository"""
|
| 19 |
if self._chat_repository is None:
|
| 20 |
self._chat_repository = ChatRepository()
|
| 21 |
return self._chat_repository
|
| 22 |
-
|
| 23 |
def _load_initial_data(self) -> List[ChatCompletion]:
|
| 24 |
"""Load initial data from JSON files"""
|
| 25 |
try:
|
|
@@ -29,30 +30,47 @@ class InitialSetup:
|
|
| 29 |
except Exception as e:
|
| 30 |
logger.error(f"Error loading initial data: {e}")
|
| 31 |
return []
|
| 32 |
-
|
| 33 |
async def setup(self) -> None:
|
| 34 |
"""Setup initial data if database type is embedded"""
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
# query of the saved chat completions
|
| 51 |
saved_chat_completions = await self.chat_repository.find()
|
| 52 |
logger.debug("********** Begin of Saved chat completions**********")
|
| 53 |
logger.trace(f"{saved_chat_completions}")
|
| 54 |
-
|
|
|
|
|
|
|
| 55 |
logger.debug("********** End of Saved chat completions**********")
|
| 56 |
|
| 57 |
-
|
| 58 |
-
logger.info("Initial setup completed successfully for embedded database")
|
|
|
|
| 6 |
from app.repository.chat_repository import ChatRepository
|
| 7 |
from app.config.db import db_config
|
| 8 |
|
| 9 |
+
|
| 10 |
class InitialSetup:
|
| 11 |
"""Initial setup manager for the application when database type is embedded"""
|
| 12 |
+
|
| 13 |
def __init__(self):
|
| 14 |
self._chat_repository: Optional[ChatRepository] = None
|
| 15 |
self.data_dir = os.path.join(os.path.dirname(__file__), "data")
|
| 16 |
+
|
| 17 |
@property
|
| 18 |
def chat_repository(self) -> ChatRepository:
|
| 19 |
"""Lazy loading of ChatRepository"""
|
| 20 |
if self._chat_repository is None:
|
| 21 |
self._chat_repository = ChatRepository()
|
| 22 |
return self._chat_repository
|
| 23 |
+
|
| 24 |
def _load_initial_data(self) -> List[ChatCompletion]:
|
| 25 |
"""Load initial data from JSON files"""
|
| 26 |
try:
|
|
|
|
| 30 |
except Exception as e:
|
| 31 |
logger.error(f"Error loading initial data: {e}")
|
| 32 |
return []
|
| 33 |
+
|
| 34 |
async def setup(self) -> None:
|
| 35 |
"""Setup initial data if database type is embedded"""
|
| 36 |
+
try:
|
| 37 |
+
# if db_config.DATABASE_TYPE != "embedded":
|
| 38 |
+
# logger.info("Skipping initial setup as database type is not embedded")
|
| 39 |
+
# return
|
| 40 |
+
|
| 41 |
+
# delete all chat completions in the embedded database
|
| 42 |
+
logger.warning("Deleting all chat completions in the embedded database")
|
| 43 |
+
await self.chat_repository.db.chat_completion.delete_many({})
|
| 44 |
+
logger.warning("Deleting all chat completions in the embedded database done")
|
| 45 |
+
|
| 46 |
+
chat_completions = self._load_initial_data()
|
| 47 |
+
logger.info(f"Loaded {len(chat_completions)} initial chat completions")
|
| 48 |
+
|
| 49 |
+
for completion in chat_completions:
|
| 50 |
+
try:
|
| 51 |
+
found_id = await self.chat_repository.find_by_id(completion.completion_id)
|
| 52 |
+
if found_id:
|
| 53 |
+
logger.debug(f"Chat completion already exists: {found_id}")
|
| 54 |
+
continue
|
| 55 |
+
|
| 56 |
+
saved = await self.chat_repository.save(completion)
|
| 57 |
+
logger.info(f"Successfully saved chat completion: {saved.completion_id}")
|
| 58 |
+
|
| 59 |
+
except Exception as e:
|
| 60 |
+
logger.error(f"Error saving chat completion {completion.completion_id}: {e}")
|
| 61 |
+
raise
|
| 62 |
+
|
| 63 |
+
except Exception as e:
|
| 64 |
+
logger.error(f"Setup failed: {e}")
|
| 65 |
+
raise
|
| 66 |
+
|
| 67 |
# query of the saved chat completions
|
| 68 |
saved_chat_completions = await self.chat_repository.find()
|
| 69 |
logger.debug("********** Begin of Saved chat completions**********")
|
| 70 |
logger.trace(f"{saved_chat_completions}")
|
| 71 |
+
chat = saved_chat_completions[0]
|
| 72 |
+
chat.messages = []
|
| 73 |
+
logger.debug(f"saved_chat_completions[0] : {chat}")
|
| 74 |
logger.debug("********** End of Saved chat completions**********")
|
| 75 |
|
| 76 |
+
logger.info("Initial setup completed successfully for embedded database")
|
|
|
app/db/client.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
# mongodb client with motor and pymongo
|
| 2 |
|
| 3 |
from abc import ABC, abstractmethod
|
| 4 |
-
from typing import
|
|
|
|
| 5 |
|
| 6 |
class DatabaseClientProtocol(Protocol):
|
| 7 |
"""
|
|
@@ -19,6 +20,7 @@ class DatabaseClientProtocol(Protocol):
|
|
| 19 |
- https://www.python.org/dev/peps/pep-0544/
|
| 20 |
- https://realpython.com/duck-typing-python/
|
| 21 |
"""
|
|
|
|
| 22 |
async def connect(self) -> None: ...
|
| 23 |
async def close(self) -> None: ...
|
| 24 |
@property
|
|
|
|
| 1 |
# mongodb client with motor and pymongo
|
| 2 |
|
| 3 |
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import Protocol
|
| 5 |
+
|
| 6 |
|
| 7 |
class DatabaseClientProtocol(Protocol):
|
| 8 |
"""
|
|
|
|
| 20 |
- https://www.python.org/dev/peps/pep-0544/
|
| 21 |
- https://realpython.com/duck-typing-python/
|
| 22 |
"""
|
| 23 |
+
|
| 24 |
async def connect(self) -> None: ...
|
| 25 |
async def close(self) -> None: ...
|
| 26 |
@property
|
app/db/embedded.py
CHANGED
|
@@ -1,12 +1,10 @@
|
|
| 1 |
-
from typing import Optional
|
| 2 |
from mongomock_motor import AsyncMongoMockClient
|
| 3 |
from app.config.db import db_config
|
| 4 |
from loguru import logger
|
| 5 |
from app.db.client import DatabaseClient
|
| 6 |
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
class EmbeddedMongoClient(DatabaseClient):
|
| 11 |
"""
|
| 12 |
Mock MongoDB client implementation for local machine and development environment.
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
from mongomock_motor import AsyncMongoMockClient
|
| 3 |
from app.config.db import db_config
|
| 4 |
from loguru import logger
|
| 5 |
from app.db.client import DatabaseClient
|
| 6 |
|
| 7 |
|
|
|
|
|
|
|
| 8 |
class EmbeddedMongoClient(DatabaseClient):
|
| 9 |
"""
|
| 10 |
Mock MongoDB client implementation for local machine and development environment.
|
app/db/factory.py
CHANGED
|
@@ -5,6 +5,7 @@ from app.db.mongo import PersistentMongoClient
|
|
| 5 |
from app.db.embedded import EmbeddedMongoClient
|
| 6 |
from app.config.db import db_config
|
| 7 |
|
|
|
|
| 8 |
class DatabaseClientFactory:
|
| 9 |
"""Factory class for creating database clients"""
|
| 10 |
|
|
@@ -33,9 +34,6 @@ class DatabaseClientFactory:
|
|
| 33 |
cls._client = EmbeddedMongoClient()
|
| 34 |
logger.info(f"Returning DatabaseClientFactory.client. Host: {cls._client.client.host}")
|
| 35 |
return cls._client
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
|
| 40 |
|
| 41 |
# Global instance
|
|
|
|
| 5 |
from app.db.embedded import EmbeddedMongoClient
|
| 6 |
from app.config.db import db_config
|
| 7 |
|
| 8 |
+
|
| 9 |
class DatabaseClientFactory:
|
| 10 |
"""Factory class for creating database clients"""
|
| 11 |
|
|
|
|
| 34 |
cls._client = EmbeddedMongoClient()
|
| 35 |
logger.info(f"Returning DatabaseClientFactory.client. Host: {cls._client.client.host}")
|
| 36 |
return cls._client
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
# Global instance
|
app/db/mongo.py
CHANGED
|
@@ -4,6 +4,7 @@ from app.config.db import db_config
|
|
| 4 |
from loguru import logger
|
| 5 |
from app.db.client import DatabaseClient
|
| 6 |
|
|
|
|
| 7 |
class PersistentMongoClient(DatabaseClient):
|
| 8 |
"""Real MongoDB client implementation"""
|
| 9 |
|
|
@@ -57,4 +58,4 @@ class PersistentMongoClient(DatabaseClient):
|
|
| 57 |
logger.warning(f"Error while closing MongoDB connection: {e}")
|
| 58 |
self._client = None
|
| 59 |
self._db = None
|
| 60 |
-
self._is_connected = False
|
|
|
|
| 4 |
from loguru import logger
|
| 5 |
from app.db.client import DatabaseClient
|
| 6 |
|
| 7 |
+
|
| 8 |
class PersistentMongoClient(DatabaseClient):
|
| 9 |
"""Real MongoDB client implementation"""
|
| 10 |
|
|
|
|
| 58 |
logger.warning(f"Error while closing MongoDB connection: {e}")
|
| 59 |
self._client = None
|
| 60 |
self._db = None
|
| 61 |
+
self._is_connected = False
|
app/mapper/base_mapper.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import TypeVar, Generic, List
|
| 3 |
+
|
| 4 |
+
T = TypeVar('T')
|
| 5 |
+
U = TypeVar('U')
|
| 6 |
+
|
| 7 |
+
class BaseMapper(Generic[T, U], ABC):
|
| 8 |
+
"""Base mapper class for mapping between model and schema objects."""
|
| 9 |
+
|
| 10 |
+
@abstractmethod
|
| 11 |
+
def to_schema(self, model: T) -> U:
|
| 12 |
+
"""Map from model to schema."""
|
| 13 |
+
pass
|
| 14 |
+
|
| 15 |
+
@abstractmethod
|
| 16 |
+
def to_model(self, schema: U) -> T:
|
| 17 |
+
"""Map from schema to model."""
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
def to_schema_list(self, models: List[T]) -> List[U]:
|
| 21 |
+
"""Map a list of models to schemas."""
|
| 22 |
+
return [self.to_schema(model) for model in models]
|
| 23 |
+
|
| 24 |
+
def to_model_list(self, schemas: List[U]) -> List[T]:
|
| 25 |
+
"""Map a list of schemas to models."""
|
| 26 |
+
return [self.to_model(schema) for schema in schemas]
|
app/mapper/chat_mapper.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
from typing import Optional
|
| 3 |
+
from app.mapper.base_mapper import BaseMapper
|
| 4 |
+
from app.model.chat_model import ChatCompletion, ChatMessage
|
| 5 |
+
from app.schema.chat_schema import ChatCompletionResponse, ChatCompletionRequest, MessageResponse, ChoiceResponse
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ChatMapper(BaseMapper[ChatCompletion, ChatCompletionResponse]):
|
| 9 |
+
"""Mapper for converting between ChatCompletion model and schema objects."""
|
| 10 |
+
|
| 11 |
+
def to_schema(self, model: ChatCompletion) -> ChatCompletionResponse:
|
| 12 |
+
"""Convert ChatCompletion model to ChatCompletionResponse schema."""
|
| 13 |
+
# Convert datetime to Unix timestamp
|
| 14 |
+
created_timestamp = int(model.created_date.timestamp()) if model.created_date else None
|
| 15 |
+
|
| 16 |
+
# Map the last message to a response message
|
| 17 |
+
last_message = model.messages[-1] if model.messages else None
|
| 18 |
+
message_response = None
|
| 19 |
+
if last_message:
|
| 20 |
+
message_response = MessageResponse(
|
| 21 |
+
message_id=last_message.message_id,
|
| 22 |
+
role=last_message.role,
|
| 23 |
+
content=last_message.content,
|
| 24 |
+
figure=last_message.figure,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# Create choice response
|
| 28 |
+
choice = ChoiceResponse(index=0, message=message_response, finish_reason="stop") if message_response else None
|
| 29 |
+
|
| 30 |
+
return ChatCompletionResponse(
|
| 31 |
+
completion_id=model.completion_id, model=model.model, created=created_timestamp, choices=[choice] if choice else []
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
def to_model(self, schema: ChatCompletionRequest) -> ChatCompletion:
|
| 35 |
+
"""Convert ChatCompletionRequest schema to ChatCompletion model."""
|
| 36 |
+
# Convert messages to ChatMessage objects
|
| 37 |
+
messages = []
|
| 38 |
+
if schema.messages:
|
| 39 |
+
for msg in schema.messages:
|
| 40 |
+
messages.append(ChatMessage(role=msg.role, content=msg.content))
|
| 41 |
+
|
| 42 |
+
return ChatCompletion(
|
| 43 |
+
completion_id=schema.completion_id,
|
| 44 |
+
model=schema.model or "gpt-4o", # Default model
|
| 45 |
+
messages=messages,
|
| 46 |
+
stream=schema.stream or False,
|
| 47 |
+
)
|
app/model/chat_model.py
CHANGED
|
@@ -1,12 +1,9 @@
|
|
| 1 |
# chat model for chat completion database
|
| 2 |
|
| 3 |
import json
|
| 4 |
-
import
|
| 5 |
-
from enum import Enum
|
| 6 |
-
from pydantic import BaseModel
|
| 7 |
from datetime import datetime
|
| 8 |
-
from
|
| 9 |
-
from typing import List, Optional
|
| 10 |
|
| 11 |
# Chat completion payload example
|
| 12 |
# {
|
|
@@ -54,70 +51,44 @@ from typing import List, Optional
|
|
| 54 |
# }
|
| 55 |
|
| 56 |
|
| 57 |
-
class MessageRole(str, Enum):
|
| 58 |
-
USER = "user"
|
| 59 |
-
ASSISTANT = "assistant"
|
| 60 |
-
SYSTEM = "system"
|
| 61 |
-
|
| 62 |
-
|
| 63 |
class ChatMessage(BaseModel):
|
| 64 |
"""
|
| 65 |
A message in a chat completion.
|
| 66 |
"""
|
| 67 |
|
| 68 |
-
message_id: str = Field(
|
| 69 |
-
|
| 70 |
-
description="The unique identifier for the message",
|
| 71 |
-
default_factory=lambda: str(uuid.uuid4()),
|
| 72 |
-
)
|
| 73 |
-
role: MessageRole = Field(
|
| 74 |
-
...,
|
| 75 |
-
description="The role of the message sender",
|
| 76 |
-
examples=[
|
| 77 |
-
MessageRole.USER,
|
| 78 |
-
MessageRole.ASSISTANT,
|
| 79 |
-
MessageRole.SYSTEM,
|
| 80 |
-
],
|
| 81 |
-
)
|
| 82 |
content: str = Field(..., description="The content of the message")
|
| 83 |
-
figure: Optional[
|
| 84 |
-
created_date: datetime = Field(
|
| 85 |
-
default_factory=datetime.now,
|
| 86 |
-
description="The timestamp of the message",
|
| 87 |
-
)
|
| 88 |
|
| 89 |
-
# write a formattet and graceful to string method
|
| 90 |
def __str__(self):
|
| 91 |
return f"""
|
| 92 |
ChatMessage(
|
| 93 |
message_id={self.message_id},
|
| 94 |
role={self.role},
|
| 95 |
content={self.content},
|
| 96 |
-
figure={json.dumps(self.figure, indent=4)},
|
| 97 |
created_date={self.created_date})
|
| 98 |
"""
|
| 99 |
-
|
| 100 |
def __repr__(self):
|
| 101 |
return self.__str__()
|
| 102 |
-
|
| 103 |
def __format__(self, format_spec):
|
| 104 |
return self.__str__()
|
| 105 |
|
| 106 |
|
| 107 |
-
|
| 108 |
class ChatCompletion(BaseModel):
|
| 109 |
"""
|
| 110 |
A chat completion.
|
| 111 |
"""
|
| 112 |
|
|
|
|
| 113 |
completion_id: Optional[str] = Field(None, description="The unique identifier for the chat completion")
|
| 114 |
|
| 115 |
# openai compatible fields
|
| 116 |
-
model: str = Field(
|
| 117 |
-
...,
|
| 118 |
-
description="The model used for the chat completion",
|
| 119 |
-
examples=["gpt-4o-mini", "gpt-4o", "gpt-3.5-turbo"],
|
| 120 |
-
)
|
| 121 |
messages: List[ChatMessage] = Field(..., description="The messages in the chat completion")
|
| 122 |
|
| 123 |
# not implemented yet
|
|
@@ -128,27 +99,36 @@ class ChatCompletion(BaseModel):
|
|
| 128 |
# presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0, description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.")
|
| 129 |
# n: int = Field(default=1, ge=1, le=10, description="How many chat completion choices to generate for each prompt.")
|
| 130 |
|
| 131 |
-
stream: bool = Field(
|
| 132 |
-
|
| 133 |
-
description="If set to true, the model response data will be streamed to the client as it is generated using
|
| 134 |
)
|
| 135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
# audit fields
|
| 137 |
-
created_by: str = Field(
|
| 138 |
created_date: datetime = Field(
|
| 139 |
default_factory=datetime.now,
|
| 140 |
description="The date and time the chat completion was created",
|
| 141 |
)
|
| 142 |
-
last_updated_by: str = Field(
|
| 143 |
-
last_updated_date: datetime = Field(
|
| 144 |
default_factory=datetime.now,
|
| 145 |
description="The date and time the chat completion was last updated",
|
| 146 |
)
|
| 147 |
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
def __str__(self):
|
| 150 |
return f"""
|
| 151 |
-
ChatCompletion(
|
| 152 |
completion_id={self.completion_id},
|
| 153 |
model={self.model},
|
| 154 |
messages={self.messages},
|
|
@@ -158,11 +138,8 @@ class ChatCompletion(BaseModel):
|
|
| 158 |
last_updated_date={self.last_updated_date})
|
| 159 |
"""
|
| 160 |
|
| 161 |
-
|
| 162 |
def __repr__(self):
|
| 163 |
return self.__str__()
|
| 164 |
|
| 165 |
-
|
| 166 |
def __format__(self, format_spec):
|
| 167 |
return self.__str__()
|
| 168 |
-
|
|
|
|
| 1 |
# chat model for chat completion database
|
| 2 |
|
| 3 |
import json
|
| 4 |
+
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
| 5 |
from datetime import datetime
|
| 6 |
+
from typing import List, Optional, Dict, Any
|
|
|
|
| 7 |
|
| 8 |
# Chat completion payload example
|
| 9 |
# {
|
|
|
|
| 51 |
# }
|
| 52 |
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
class ChatMessage(BaseModel):
|
| 55 |
"""
|
| 56 |
A message in a chat completion.
|
| 57 |
"""
|
| 58 |
|
| 59 |
+
message_id: str = Field(..., description="The unique identifier for the message")
|
| 60 |
+
role: Optional[str] = Field(None, description="The role of the message sender", examples=["user", "assistant", "system"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
content: str = Field(..., description="The content of the message")
|
| 62 |
+
figure: Optional[Dict[str, Any]] = Field(None, description="The figure data for visualization")
|
| 63 |
+
created_date: datetime = Field(default_factory=datetime.now, description="The timestamp of the message")
|
|
|
|
|
|
|
|
|
|
| 64 |
|
|
|
|
| 65 |
def __str__(self):
|
| 66 |
return f"""
|
| 67 |
ChatMessage(
|
| 68 |
message_id={self.message_id},
|
| 69 |
role={self.role},
|
| 70 |
content={self.content},
|
| 71 |
+
figure={json.dumps(self.figure, indent=4) if self.figure else None},
|
| 72 |
created_date={self.created_date})
|
| 73 |
"""
|
| 74 |
+
|
| 75 |
def __repr__(self):
|
| 76 |
return self.__str__()
|
| 77 |
+
|
| 78 |
def __format__(self, format_spec):
|
| 79 |
return self.__str__()
|
| 80 |
|
| 81 |
|
|
|
|
| 82 |
class ChatCompletion(BaseModel):
|
| 83 |
"""
|
| 84 |
A chat completion.
|
| 85 |
"""
|
| 86 |
|
| 87 |
+
#id: Optional[str] = Field(None, alias="_id", description="MongoDB document ID")
|
| 88 |
completion_id: Optional[str] = Field(None, description="The unique identifier for the chat completion")
|
| 89 |
|
| 90 |
# openai compatible fields
|
| 91 |
+
model: Optional[str] = Field(None, description="The model used for the chat completion", examples=["gpt-4o-mini", "gpt-4o", "gpt-3.5-turbo"])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
messages: List[ChatMessage] = Field(..., description="The messages in the chat completion")
|
| 93 |
|
| 94 |
# not implemented yet
|
|
|
|
| 99 |
# presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0, description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.")
|
| 100 |
# n: int = Field(default=1, ge=1, le=10, description="How many chat completion choices to generate for each prompt.")
|
| 101 |
|
| 102 |
+
stream: Optional[bool] = Field(
|
| 103 |
+
None,
|
| 104 |
+
description="If set to true, the model response data will be streamed to the client as it is generated using server-sent events.",
|
| 105 |
)
|
| 106 |
|
| 107 |
+
title: Optional[str] = Field(None, description="The title of the chat completion")
|
| 108 |
+
object_field: Optional[str] = Field(None, alias="object_field", description="The object field of the chat completion")
|
| 109 |
+
is_archived: Optional[bool] = Field(None, description="Whether the chat completion is archived")
|
| 110 |
+
is_starred: Optional[bool] = Field(None, description="Whether the chat completion is starred")
|
| 111 |
+
|
| 112 |
# audit fields
|
| 113 |
+
created_by: Optional[str] = Field(None, description="The user who created the chat completion")
|
| 114 |
created_date: datetime = Field(
|
| 115 |
default_factory=datetime.now,
|
| 116 |
description="The date and time the chat completion was created",
|
| 117 |
)
|
| 118 |
+
last_updated_by: Optional[str] = Field(None, description="The user who last updated the chat completion")
|
| 119 |
+
last_updated_date: Optional[datetime] = Field(
|
| 120 |
default_factory=datetime.now,
|
| 121 |
description="The date and time the chat completion was last updated",
|
| 122 |
)
|
| 123 |
|
| 124 |
+
|
| 125 |
+
class Config:
|
| 126 |
+
populate_by_name = True
|
| 127 |
+
arbitrary_types_allowed = True
|
| 128 |
+
|
| 129 |
def __str__(self):
|
| 130 |
return f"""
|
| 131 |
+
ChatCompletion(
|
| 132 |
completion_id={self.completion_id},
|
| 133 |
model={self.model},
|
| 134 |
messages={self.messages},
|
|
|
|
| 138 |
last_updated_date={self.last_updated_date})
|
| 139 |
"""
|
| 140 |
|
|
|
|
| 141 |
def __repr__(self):
|
| 142 |
return self.__str__()
|
| 143 |
|
|
|
|
| 144 |
def __format__(self, format_spec):
|
| 145 |
return self.__str__()
|
|
|
app/repository/chat_repository.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
|
|
| 1 |
from typing import List
|
| 2 |
from app.db.factory import db_client
|
| 3 |
from app.model.chat_model import ChatMessage, ChatCompletion
|
| 4 |
from loguru import logger
|
| 5 |
import uuid
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
class ChatRepository:
|
|
@@ -15,61 +17,125 @@ class ChatRepository:
|
|
| 15 |
"""
|
| 16 |
Upsert a chat completion into the database. If the chat completion already exists, it will be updated. If it does not exist, it will be created.
|
| 17 |
"""
|
| 18 |
-
logger.debug(f"BEGIN: save chat completion.
|
|
|
|
| 19 |
if entity.completion_id is None:
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
else:
|
| 22 |
logger.debug(f"completion_id is not None. Using the existing one: {entity.completion_id}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
upserted_entity = await self.db.chat_completion.update_one(query, update, upsert=True)
|
| 28 |
-
logger.debug(f"upserted_entity: {upserted_entity}")
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
# save conversation if new chat completion
|
| 34 |
# TODO: save conversation
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
"""
|
| 42 |
Find a chat completion by a given query. with pagination
|
| 43 |
Example : query = {"created_by": "admin"}
|
| 44 |
"""
|
| 45 |
-
logger.debug(f"BEGIN: find chat completion.
|
| 46 |
skip = (page - 1) * limit
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
logger.
|
| 52 |
-
|
| 53 |
-
logger.debug("END: find chat completion")
|
| 54 |
-
return result
|
| 55 |
|
| 56 |
-
async def find_by_id(self, completion_id: str,
|
| 57 |
"""
|
| 58 |
Find a chat completion by a given id.
|
| 59 |
Example : completion_id = "123"
|
| 60 |
"""
|
| 61 |
-
logger.debug(f"BEGIN: find chat completion by id. input parameters: completion_id: {completion_id},
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
async def find_messages(self, completion_id: str) -> List[ChatMessage]:
|
| 68 |
"""
|
| 69 |
Find all messages for a given chat completion id.
|
| 70 |
Example : completion_id = "123"
|
| 71 |
"""
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
from typing import List
|
| 3 |
from app.db.factory import db_client
|
| 4 |
from app.model.chat_model import ChatMessage, ChatCompletion
|
| 5 |
from loguru import logger
|
| 6 |
import uuid
|
| 7 |
+
import pymongo
|
| 8 |
|
| 9 |
|
| 10 |
class ChatRepository:
|
|
|
|
| 17 |
"""
|
| 18 |
Upsert a chat completion into the database. If the chat completion already exists, it will be updated. If it does not exist, it will be created.
|
| 19 |
"""
|
| 20 |
+
logger.debug(f"BEGIN REPO: save chat completion. username: {entity.created_by}, completion_id: {entity.completion_id}")
|
| 21 |
+
entity_dict = entity.model_dump(by_alias=True)
|
| 22 |
if entity.completion_id is None:
|
| 23 |
+
generated_completion_id = str(uuid.uuid4())
|
| 24 |
+
logger.warning(f"completion_id was None. Generated a new one: {generated_completion_id}")
|
| 25 |
+
entity.completion_id = generated_completion_id
|
| 26 |
+
entity_dict["completion_id"] = generated_completion_id
|
| 27 |
+
|
| 28 |
+
query = {"completion_id": generated_completion_id}
|
| 29 |
+
update = {
|
| 30 |
+
"$set": entity_dict,
|
| 31 |
+
"$setOnInsert": {
|
| 32 |
+
"created_date": entity.created_date.isoformat()
|
| 33 |
+
if entity.created_date
|
| 34 |
+
else datetime.datetime.now(datetime.timezone.utc).isoformat(),
|
| 35 |
+
"created_by": entity.created_by,
|
| 36 |
+
},
|
| 37 |
+
}
|
| 38 |
else:
|
| 39 |
logger.debug(f"completion_id is not None. Using the existing one: {entity.completion_id}")
|
| 40 |
+
query = {"completion_id": entity.completion_id}
|
| 41 |
+
update_payload = entity_dict.copy()
|
| 42 |
+
update_payload.pop("created_date", None) # created_date is not updatable
|
| 43 |
+
update_payload.pop("created_by", None) # created_by is not updatable
|
| 44 |
+
update_payload.pop("completion_id", None) # completion_id is not updatable
|
| 45 |
+
update = {"$set": update_payload}
|
| 46 |
|
| 47 |
+
upsert_result = await self.db.chat_completion.update_one(query, update, upsert=True)
|
| 48 |
+
logger.debug(f"upserted_entity. _id: {upsert_result.upserted_id}")
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
+
if upsert_result.upserted_id:
|
| 51 |
+
logger.info(f"Inserted new chat completion with ID: {entity.completion_id} (upserted_id: {upsert_result.upserted_id})")
|
| 52 |
+
elif upsert_result.modified_count > 0:
|
| 53 |
+
logger.info(f"Updated existing chat completion with ID: {entity.completion_id}")
|
| 54 |
+
elif upsert_result.matched_count > 0:
|
| 55 |
+
logger.info(f"Chat completion with ID: {entity.completion_id} matched but not modified.")
|
| 56 |
+
else:
|
| 57 |
+
logger.warning(
|
| 58 |
+
f"Chat completion with ID: {entity.completion_id} - No operation performed (no match, no upsert, no modification). This might be unexpected."
|
| 59 |
+
)
|
| 60 |
|
| 61 |
# save conversation if new chat completion
|
| 62 |
# TODO: save conversation
|
| 63 |
|
| 64 |
+
# after upsert, find the final db entity. if is there any trigger, default, etc. return the latest one
|
| 65 |
+
final_entity = await self.find_by_id(entity.completion_id)
|
| 66 |
+
if not final_entity:
|
| 67 |
+
# This should not happen, upsert should be successful.
|
| 68 |
+
logger.error(f"CRITICAL: Failed to retrieve chat completion {entity.completion_id} immediately after save operation.")
|
| 69 |
+
raise Exception(f"Data integrity issue: Could not find chat completion {entity.completion_id} after saving.")
|
| 70 |
|
| 71 |
+
logger.debug(f"END REPO: save chat completion, returning final_entity_id: {final_entity.completion_id}")
|
| 72 |
+
return final_entity
|
| 73 |
+
|
| 74 |
+
async def find(
|
| 75 |
+
self, query: dict = {}, page: int = 1, limit: int = 10, sort: dict = {"created_date": -1}, projection: dict = None
|
| 76 |
+
) -> List[ChatCompletion]:
|
| 77 |
"""
|
| 78 |
Find a chat completion by a given query. with pagination
|
| 79 |
Example : query = {"created_by": "admin"}
|
| 80 |
"""
|
| 81 |
+
logger.debug(f"BEGIN REPO: find chat completion. query: {query}, page: {page}, limit: {limit}, sort: {sort}, projection: {projection}")
|
| 82 |
skip = (page - 1) * limit
|
| 83 |
+
sort_query = sort if sort else [("created_date", pymongo.DESCENDING)]
|
| 84 |
+
|
| 85 |
+
cursor = self.db.chat_completion.find(query, projection).skip(skip).limit(limit).sort(sort_query)
|
| 86 |
+
db_docs = await cursor.to_list(length=limit)
|
| 87 |
+
result_models = []
|
| 88 |
+
for item in db_docs:
|
| 89 |
+
try:
|
| 90 |
+
result_models.append(ChatCompletion(**item))
|
| 91 |
+
except Exception as e:
|
| 92 |
+
logger.error(f"Error parsing ChatCompletion from DB for item with id {item.get('_id', 'N/A')}: {e}", exc_info=True)
|
| 93 |
+
# TODO: handle error
|
| 94 |
|
| 95 |
+
logger.trace(f"REPO find result (raw): {db_docs}")
|
| 96 |
+
logger.trace(f"REPO find result (models): {result_models}")
|
| 97 |
+
logger.debug(f"END REPO: find, returning {len(result_models)} models.")
|
| 98 |
+
return result_models
|
|
|
|
|
|
|
| 99 |
|
| 100 |
+
async def find_by_id(self, completion_id: str, projection: dict = None) -> ChatCompletion:
|
| 101 |
"""
|
| 102 |
Find a chat completion by a given id.
|
| 103 |
Example : completion_id = "123"
|
| 104 |
"""
|
| 105 |
+
logger.debug(f"BEGIN REPO: find chat completion by id. input parameters: completion_id: {completion_id}, projection: {projection}")
|
| 106 |
+
|
| 107 |
+
entity_doc = await self.db.chat_completion.find_one({"completion_id": completion_id}, projection)
|
| 108 |
+
|
| 109 |
+
if entity_doc:
|
| 110 |
+
logger.trace(f"REPO find_by_id. Found entity_doc: {entity_doc}")
|
| 111 |
+
try:
|
| 112 |
+
final_entity = ChatCompletion(**entity_doc)
|
| 113 |
+
logger.debug(f"END REPO: find_by_id. Found: {final_entity.completion_id}")
|
| 114 |
+
return final_entity
|
| 115 |
+
except Exception as e:
|
| 116 |
+
logger.error(f"Error parsing ChatCompletion from DB for id {completion_id}: {e}", exc_info=True)
|
| 117 |
+
return None # Parse hatası durumunda None döndür
|
| 118 |
+
else:
|
| 119 |
+
logger.info(f"Chat completion with ID {completion_id} not found in DB.")
|
| 120 |
+
return None
|
| 121 |
|
| 122 |
async def find_messages(self, completion_id: str) -> List[ChatMessage]:
|
| 123 |
"""
|
| 124 |
Find all messages for a given chat completion id.
|
| 125 |
Example : completion_id = "123"
|
| 126 |
"""
|
| 127 |
+
logger.debug(f"BEGIN REPO: find messages for chat completion id. input parameters: completion_id: {completion_id}")
|
| 128 |
+
projection = {"messages": 1, "_id": 0}
|
| 129 |
+
chat_doc = await self.db.chat_completion.find_one({"completion_id": completion_id}, projection)
|
| 130 |
+
|
| 131 |
+
if chat_doc and "messages" in chat_doc and chat_doc["messages"]:
|
| 132 |
+
try:
|
| 133 |
+
messages_list = [ChatMessage(**item) for item in chat_doc["messages"]]
|
| 134 |
+
logger.debug(f"END REPO: find_messages. Found {len(messages_list)} messages.")
|
| 135 |
+
return messages_list
|
| 136 |
+
except Exception as e:
|
| 137 |
+
logger.error(f"Error parsing messages for completion_id {completion_id}: {e}", exc_info=True)
|
| 138 |
+
return []
|
| 139 |
+
|
| 140 |
+
logger.info(f"No messages found for completion_id {completion_id} or messages field is empty/missing.")
|
| 141 |
+
return []
|
app/schema/chat_schema.py
CHANGED
|
@@ -1,18 +1,5 @@
|
|
| 1 |
from typing import List, Optional
|
| 2 |
from pydantic import BaseModel, Field
|
| 3 |
-
from enum import Enum
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
class ChoiceSchemaFinishReasonType(str, Enum):
|
| 7 |
-
STOP = "stop"
|
| 8 |
-
LENGTH = "length"
|
| 9 |
-
CONTENT_FILTER = "content_filter"
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class MessageSchemaRoleType(str, Enum):
|
| 13 |
-
USER = "user"
|
| 14 |
-
ASSISTANT = "assistant"
|
| 15 |
-
SYSTEM = "system"
|
| 16 |
|
| 17 |
|
| 18 |
class MessageRequest(BaseModel):
|
|
@@ -20,15 +7,15 @@ class MessageRequest(BaseModel):
|
|
| 20 |
Represents a message in a chat completion.
|
| 21 |
"""
|
| 22 |
|
| 23 |
-
role:
|
| 24 |
-
content: str = Field(
|
| 25 |
-
|
| 26 |
|
| 27 |
|
| 28 |
class ChatCompletionRequest(BaseModel):
|
| 29 |
"""
|
| 30 |
Represents a chat completion request. Starting a new chat or continuing a previous chat.
|
| 31 |
"""
|
|
|
|
| 32 |
completion_id: Optional[str] = Field(
|
| 33 |
None,
|
| 34 |
description="The unique identifier for the chat completion. When starting a new chat, this will be a new UUID. When continuing a previous chat, this will be the same as the previous chat completion id.",
|
|
@@ -44,15 +31,16 @@ class MessageResponse(BaseModel):
|
|
| 44 |
"""
|
| 45 |
|
| 46 |
message_id: Optional[str] = Field(None, description="The unique identifier for the message")
|
| 47 |
-
role: Optional[
|
| 48 |
content: Optional[str] = Field(None, description="The content of the message")
|
| 49 |
figure: Optional[dict] = Field(None, description="The figure data to be visualized")
|
| 50 |
|
| 51 |
|
| 52 |
class ChoiceResponse(BaseModel):
|
| 53 |
-
finish_reason: Optional[
|
| 54 |
None,
|
| 55 |
description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters",
|
|
|
|
| 56 |
)
|
| 57 |
index: Optional[int] = Field(None, description="The index of the choice in the list of choices.")
|
| 58 |
message: Optional[MessageResponse] = Field(None, description="The message to use for the chat completion")
|
|
|
|
| 1 |
from typing import List, Optional
|
| 2 |
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
class MessageRequest(BaseModel):
|
|
|
|
| 7 |
Represents a message in a chat completion.
|
| 8 |
"""
|
| 9 |
|
| 10 |
+
role: Optional[str] = Field(None, description="The role of the message", examples=["user", "assistant", "system"])
|
| 11 |
+
content: Optional[str] = Field(None, description="The content of the message")
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class ChatCompletionRequest(BaseModel):
|
| 15 |
"""
|
| 16 |
Represents a chat completion request. Starting a new chat or continuing a previous chat.
|
| 17 |
"""
|
| 18 |
+
|
| 19 |
completion_id: Optional[str] = Field(
|
| 20 |
None,
|
| 21 |
description="The unique identifier for the chat completion. When starting a new chat, this will be a new UUID. When continuing a previous chat, this will be the same as the previous chat completion id.",
|
|
|
|
| 31 |
"""
|
| 32 |
|
| 33 |
message_id: Optional[str] = Field(None, description="The unique identifier for the message")
|
| 34 |
+
role: Optional[str] = Field(None, description="The role of the message", examples=["user", "assistant", "system"])
|
| 35 |
content: Optional[str] = Field(None, description="The content of the message")
|
| 36 |
figure: Optional[dict] = Field(None, description="The figure data to be visualized")
|
| 37 |
|
| 38 |
|
| 39 |
class ChoiceResponse(BaseModel):
|
| 40 |
+
finish_reason: Optional[str] = Field(
|
| 41 |
None,
|
| 42 |
description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters",
|
| 43 |
+
examples=["stop", "length", "content_filter"],
|
| 44 |
)
|
| 45 |
index: Optional[int] = Field(None, description="The index of the choice in the list of choices.")
|
| 46 |
message: Optional[MessageResponse] = Field(None, description="The message to use for the chat completion")
|
app/security/auth_service.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
from app.config.
|
| 2 |
-
from fastapi import HTTPException, status, Security
|
| 3 |
from fastapi.security import APIKeyHeader
|
| 4 |
from loguru import logger
|
| 5 |
import base64
|
|
@@ -12,13 +12,14 @@ api_key_header = APIKeyHeader(
|
|
| 12 |
name="Authorization",
|
| 13 |
scheme_name="ApiKeyAuth",
|
| 14 |
description="API key in the format: sk-{username}-{base64_encoded_data}",
|
|
|
|
| 15 |
)
|
| 16 |
|
| 17 |
|
| 18 |
class AuthService:
|
| 19 |
-
def __init__(self):
|
| 20 |
-
self.secret = secret
|
| 21 |
self.api_key_header = api_key_header
|
|
|
|
| 22 |
|
| 23 |
def decode_api_key(self, api_key: str) -> str:
|
| 24 |
"""Decode API key to extract username and verify signature."""
|
|
@@ -58,10 +59,10 @@ class AuthService:
|
|
| 58 |
}
|
| 59 |
json_str = json.dumps(json_data)
|
| 60 |
logger.trace(f"JSON data for signature: {json_str}")
|
| 61 |
-
logger.trace(f"Secret key: {self.
|
| 62 |
|
| 63 |
expected_signature = hmac.new(
|
| 64 |
-
self.
|
| 65 |
json_str.encode(),
|
| 66 |
hashlib.sha256,
|
| 67 |
).hexdigest()
|
|
@@ -86,11 +87,21 @@ class AuthService:
|
|
| 86 |
detail=f"Invalid API key: {str(e)}",
|
| 87 |
)
|
| 88 |
|
| 89 |
-
def verify_credentials(self, api_key: str =
|
| 90 |
"""Verify API key and extract username."""
|
| 91 |
logger.trace(f"BEGIN: api_key: {api_key}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
username = self.decode_api_key(api_key)
|
| 93 |
-
|
| 94 |
result = username
|
| 95 |
logger.trace(f"END: result: {result}")
|
| 96 |
return result
|
|
|
|
| 1 |
+
from app.config.security_config import get_security_config
|
| 2 |
+
from fastapi import HTTPException, status, Security, Depends
|
| 3 |
from fastapi.security import APIKeyHeader
|
| 4 |
from loguru import logger
|
| 5 |
import base64
|
|
|
|
| 12 |
name="Authorization",
|
| 13 |
scheme_name="ApiKeyAuth",
|
| 14 |
description="API key in the format: sk-{username}-{base64_encoded_data}",
|
| 15 |
+
auto_error=False # API key olmadığında otomatik hata vermesini engelle
|
| 16 |
)
|
| 17 |
|
| 18 |
|
| 19 |
class AuthService:
|
| 20 |
+
def __init__(self):
|
|
|
|
| 21 |
self.api_key_header = api_key_header
|
| 22 |
+
self.security_config = get_security_config()
|
| 23 |
|
| 24 |
def decode_api_key(self, api_key: str) -> str:
|
| 25 |
"""Decode API key to extract username and verify signature."""
|
|
|
|
| 59 |
}
|
| 60 |
json_str = json.dumps(json_data)
|
| 61 |
logger.trace(f"JSON data for signature: {json_str}")
|
| 62 |
+
logger.trace(f"Secret key: {self.security_config.SECRET_KEY}")
|
| 63 |
|
| 64 |
expected_signature = hmac.new(
|
| 65 |
+
self.security_config.SECRET_KEY.encode(),
|
| 66 |
json_str.encode(),
|
| 67 |
hashlib.sha256,
|
| 68 |
).hexdigest()
|
|
|
|
| 87 |
detail=f"Invalid API key: {str(e)}",
|
| 88 |
)
|
| 89 |
|
| 90 |
+
async def verify_credentials(self, api_key: str = Security(api_key_header)) -> str:
|
| 91 |
"""Verify API key and extract username."""
|
| 92 |
logger.trace(f"BEGIN: api_key: {api_key}")
|
| 93 |
+
|
| 94 |
+
if not self.security_config.ENABLED:
|
| 95 |
+
logger.info("Security is disabled, using default username: " + self.security_config.DEFAULT_USERNAME)
|
| 96 |
+
return self.security_config.DEFAULT_USERNAME
|
| 97 |
+
|
| 98 |
+
if not api_key:
|
| 99 |
+
raise HTTPException(
|
| 100 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 101 |
+
detail="API key is required when security is enabled",
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
username = self.decode_api_key(api_key)
|
|
|
|
| 105 |
result = username
|
| 106 |
logger.trace(f"END: result: {result}")
|
| 107 |
return result
|
app/service/chat_service.py
CHANGED
|
@@ -8,6 +8,7 @@ from app.schema.chat_schema import (
|
|
| 8 |
MessageResponse,
|
| 9 |
)
|
| 10 |
from app.model.chat_model import ChatCompletion, ChatMessage
|
|
|
|
| 11 |
import uuid
|
| 12 |
from loguru import logger
|
| 13 |
from app.schema.conversation import ConversationResponse
|
|
@@ -16,13 +17,16 @@ from app.schema.conversation import ConversationResponse
|
|
| 16 |
class ChatService:
|
| 17 |
def __init__(self):
|
| 18 |
self.chat_repository = ChatRepository()
|
|
|
|
| 19 |
|
| 20 |
async def handle_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
| 21 |
last_user_message = request.messages[-1].content
|
| 22 |
response_content = f"TODO implement ai-agent response for this message: {last_user_message}"
|
| 23 |
username = "admin"
|
| 24 |
|
| 25 |
-
|
|
|
|
|
|
|
| 26 |
if entity.completion_id:
|
| 27 |
entity.completion_id = str(uuid.uuid4())
|
| 28 |
entity.created_by = username
|
|
@@ -31,28 +35,20 @@ class ChatService:
|
|
| 31 |
entity.last_updated_by = username
|
| 32 |
entity.last_updated_date = datetime.datetime.now()
|
| 33 |
|
|
|
|
| 34 |
entity = await self.chat_repository.save(entity)
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
result.choices = [
|
| 39 |
-
ChoiceResponse(
|
| 40 |
-
**{
|
| 41 |
-
"index": 0,
|
| 42 |
-
"message": messages[0],
|
| 43 |
-
"finish_reason": "stop",
|
| 44 |
-
}
|
| 45 |
-
)
|
| 46 |
-
]
|
| 47 |
-
return result
|
| 48 |
|
| 49 |
async def find(self, query: dict, page: int, limit: int, sort: dict, project: dict = None) -> List[ChatCompletionResponse]:
|
| 50 |
logger.debug(f"BEGIN SERVICE: find for query: {query}, page: {page}, limit: {limit}, sort: {sort}, project: {project}")
|
| 51 |
entities = await self.chat_repository.find(query, page, limit, sort, project)
|
| 52 |
-
return
|
| 53 |
|
| 54 |
-
async def find_by_id(self, completion_id: str, project: dict = None) ->
|
| 55 |
-
|
|
|
|
| 56 |
|
| 57 |
async def find_messages(self, completion_id: str) -> List[ChatMessage]:
|
| 58 |
return await self.chat_repository.find_messages(completion_id)
|
|
|
|
| 8 |
MessageResponse,
|
| 9 |
)
|
| 10 |
from app.model.chat_model import ChatCompletion, ChatMessage
|
| 11 |
+
from app.mapper.chat_mapper import ChatMapper
|
| 12 |
import uuid
|
| 13 |
from loguru import logger
|
| 14 |
from app.schema.conversation import ConversationResponse
|
|
|
|
| 17 |
class ChatService:
|
| 18 |
def __init__(self):
|
| 19 |
self.chat_repository = ChatRepository()
|
| 20 |
+
self.chat_mapper = ChatMapper()
|
| 21 |
|
| 22 |
async def handle_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
| 23 |
last_user_message = request.messages[-1].content
|
| 24 |
response_content = f"TODO implement ai-agent response for this message: {last_user_message}"
|
| 25 |
username = "admin"
|
| 26 |
|
| 27 |
+
# Convert request to model
|
| 28 |
+
entity = self.chat_mapper.to_model(request)
|
| 29 |
+
|
| 30 |
if entity.completion_id:
|
| 31 |
entity.completion_id = str(uuid.uuid4())
|
| 32 |
entity.created_by = username
|
|
|
|
| 35 |
entity.last_updated_by = username
|
| 36 |
entity.last_updated_date = datetime.datetime.now()
|
| 37 |
|
| 38 |
+
# Save to database
|
| 39 |
entity = await self.chat_repository.save(entity)
|
| 40 |
|
| 41 |
+
# Convert model to response
|
| 42 |
+
return self.chat_mapper.to_schema(entity)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
async def find(self, query: dict, page: int, limit: int, sort: dict, project: dict = None) -> List[ChatCompletionResponse]:
|
| 45 |
logger.debug(f"BEGIN SERVICE: find for query: {query}, page: {page}, limit: {limit}, sort: {sort}, project: {project}")
|
| 46 |
entities = await self.chat_repository.find(query, page, limit, sort, project)
|
| 47 |
+
return self.chat_mapper.to_schema_list(entities)
|
| 48 |
|
| 49 |
+
async def find_by_id(self, completion_id: str, project: dict = None) -> ChatCompletionResponse:
|
| 50 |
+
entity = await self.chat_repository.find_by_id(completion_id, project)
|
| 51 |
+
return self.chat_mapper.to_schema(entity) if entity else None
|
| 52 |
|
| 53 |
async def find_messages(self, completion_id: str) -> List[ChatMessage]:
|
| 54 |
return await self.chat_repository.find_messages(completion_id)
|
main.py
CHANGED
|
@@ -35,11 +35,11 @@ async def lifespan(app: FastAPI):
|
|
| 35 |
# Startup
|
| 36 |
logger.info("Starting up application...")
|
| 37 |
await db_client.connect()
|
| 38 |
-
|
| 39 |
# Run initial setup if database type is embedded
|
| 40 |
initial_setup = InitialSetup()
|
| 41 |
await initial_setup.setup()
|
| 42 |
-
|
| 43 |
yield
|
| 44 |
|
| 45 |
# Shutdown
|
|
|
|
| 35 |
# Startup
|
| 36 |
logger.info("Starting up application...")
|
| 37 |
await db_client.connect()
|
| 38 |
+
|
| 39 |
# Run initial setup if database type is embedded
|
| 40 |
initial_setup = InitialSetup()
|
| 41 |
await initial_setup.setup()
|
| 42 |
+
|
| 43 |
yield
|
| 44 |
|
| 45 |
# Shutdown
|