style: format with ruff
Browse files- app/api/chat_api.py +4 -2
- app/mapper/chat_mapper.py +0 -2
- app/model/chat_model.py +1 -4
- app/repository/chat_repository.py +16 -14
- app/schema/chat_schema.py +1 -3
- app/security/auth_service.py +1 -1
- app/service/chat_service.py +8 -2
app/api/chat_api.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
from typing import Any, List, Optional
|
| 4 |
from fastapi import APIRouter, HTTPException, Depends, Request
|
| 5 |
from pydantic import BaseModel
|
| 6 |
-
from app.schema.chat_schema import ChatCompletionRequest, ChatCompletionResponse, MessageResponse
|
| 7 |
from app.schema.conversation_schema import ConversationResponse, ConversationItemResponse
|
| 8 |
from app.service.chat_service import ChatService
|
| 9 |
from app.security.auth_service import AuthService
|
|
@@ -96,7 +96,9 @@ async def list_messages(completion_id: str, request: Request, username: str = De
|
|
| 96 |
# plot api list
|
| 97 |
################
|
| 98 |
# get a plot for a message
|
| 99 |
-
@router.get(
|
|
|
|
|
|
|
| 100 |
async def retrieve_plot(completion_id: str, message_id: str, request: Request, username: str = Depends(auth_service.verify_credentials)):
|
| 101 |
"""
|
| 102 |
Get a plot figure for a message to visualize the data
|
|
|
|
| 3 |
from typing import Any, List, Optional
|
| 4 |
from fastapi import APIRouter, HTTPException, Depends, Request
|
| 5 |
from pydantic import BaseModel
|
| 6 |
+
from app.schema.chat_schema import ChatCompletionRequest, ChatCompletionResponse, MessageResponse
|
| 7 |
from app.schema.conversation_schema import ConversationResponse, ConversationItemResponse
|
| 8 |
from app.service.chat_service import ChatService
|
| 9 |
from app.security.auth_service import AuthService
|
|
|
|
| 96 |
# plot api list
|
| 97 |
################
|
| 98 |
# get a plot for a message
|
| 99 |
+
@router.get(
|
| 100 |
+
"/chat/completions/{completion_id}/messages/{message_id}/plot", response_model=Optional[dict[str, Any]], response_model_exclude_none=True
|
| 101 |
+
)
|
| 102 |
async def retrieve_plot(completion_id: str, message_id: str, request: Request, username: str = Depends(auth_service.verify_credentials)):
|
| 103 |
"""
|
| 104 |
Get a plot figure for a message to visualize the data
|
app/mapper/chat_mapper.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
from datetime import datetime
|
| 2 |
-
from typing import Optional
|
| 3 |
from app.mapper.base_mapper import BaseMapper
|
| 4 |
from app.model.chat_model import ChatCompletion, ChatMessage
|
| 5 |
from app.schema.chat_schema import ChatCompletionResponse, ChatCompletionRequest, MessageResponse, ChoiceResponse
|
|
|
|
|
|
|
|
|
|
| 1 |
from app.mapper.base_mapper import BaseMapper
|
| 2 |
from app.model.chat_model import ChatCompletion, ChatMessage
|
| 3 |
from app.schema.chat_schema import ChatCompletionResponse, ChatCompletionRequest, MessageResponse, ChoiceResponse
|
app/model/chat_model.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
| 1 |
# chat model for chat completion database
|
| 2 |
|
| 3 |
-
import json
|
| 4 |
from bson import ObjectId
|
| 5 |
from pydantic import BaseModel, Field
|
| 6 |
from datetime import datetime
|
| 7 |
-
from typing import List, Optional,
|
| 8 |
|
| 9 |
# Chat completion payload example
|
| 10 |
# {
|
|
@@ -51,8 +50,6 @@ from typing import List, Optional, Dict, Any
|
|
| 51 |
# "stream": false
|
| 52 |
# }
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
| 56 |
|
| 57 |
class ChatMessage(BaseModel):
|
| 58 |
"""
|
|
|
|
| 1 |
# chat model for chat completion database
|
| 2 |
|
|
|
|
| 3 |
from bson import ObjectId
|
| 4 |
from pydantic import BaseModel, Field
|
| 5 |
from datetime import datetime
|
| 6 |
+
from typing import List, Optional, Any
|
| 7 |
|
| 8 |
# Chat completion payload example
|
| 9 |
# {
|
|
|
|
| 50 |
# "stream": false
|
| 51 |
# }
|
| 52 |
|
|
|
|
|
|
|
| 53 |
|
| 54 |
class ChatMessage(BaseModel):
|
| 55 |
"""
|
app/repository/chat_repository.py
CHANGED
|
@@ -181,14 +181,16 @@ class ChatRepository:
|
|
| 181 |
logger.trace(f"REPO find_messages. chat_doc: {chat_doc}")
|
| 182 |
if chat_doc and "messages" in chat_doc and chat_doc["messages"]:
|
| 183 |
try:
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
|
|
|
|
|
|
| 192 |
logger.debug(f"END REPO: find_messages. Found {len(messages_list)} messages.")
|
| 193 |
return messages_list
|
| 194 |
except Exception as e:
|
|
@@ -197,14 +199,14 @@ class ChatRepository:
|
|
| 197 |
|
| 198 |
logger.info(f"No messages found for completion_id {completion_id} or messages field is empty/missing.")
|
| 199 |
return []
|
| 200 |
-
|
| 201 |
async def find_plot_by_message(self, completion_id: str, message_id: str) -> Optional[dict[str, Any]]:
|
| 202 |
"""
|
| 203 |
Find a plot by a given message id.
|
| 204 |
Example : completion_id = "123", message_id = "123"
|
| 205 |
"""
|
| 206 |
logger.debug(f"BEGIN REPO: find plot by message id. input parameters: completion_id: {completion_id}, message_id: {message_id}")
|
| 207 |
-
|
| 208 |
query = {"completion_id": completion_id}
|
| 209 |
projection = {"messages": 1, "_id": 0}
|
| 210 |
try:
|
|
@@ -213,7 +215,7 @@ class ChatRepository:
|
|
| 213 |
except Exception as e:
|
| 214 |
logger.error(f"Error finding plot by message id: {e}")
|
| 215 |
return None
|
| 216 |
-
|
| 217 |
# Mesajları Python tarafında filtreleyelim
|
| 218 |
if entity_doc and "messages" in entity_doc and entity_doc["messages"]:
|
| 219 |
try:
|
|
@@ -223,7 +225,7 @@ class ChatRepository:
|
|
| 223 |
# figure = message.get("figure")
|
| 224 |
# logger.debug(f"REPO find figure: {figure}")
|
| 225 |
# return figure
|
| 226 |
-
|
| 227 |
match = next((message for message in entity_doc["messages"] if message["message_id"] == message_id), None)
|
| 228 |
if match:
|
| 229 |
figure = match.get("figure")
|
|
@@ -231,7 +233,7 @@ class ChatRepository:
|
|
| 231 |
return figure
|
| 232 |
else:
|
| 233 |
logger.warning(f"Message with ID {message_id} not found")
|
| 234 |
-
|
| 235 |
logger.warning(f"Message with ID {message_id} not found")
|
| 236 |
return None
|
| 237 |
except Exception as e:
|
|
@@ -239,4 +241,4 @@ class ChatRepository:
|
|
| 239 |
return None
|
| 240 |
else:
|
| 241 |
logger.warning(f"REPO find_plot_by_message. entity_doc: {entity_doc}")
|
| 242 |
-
return None
|
|
|
|
| 181 |
logger.trace(f"REPO find_messages. chat_doc: {chat_doc}")
|
| 182 |
if chat_doc and "messages" in chat_doc and chat_doc["messages"]:
|
| 183 |
try:
|
| 184 |
+
messages_list = [
|
| 185 |
+
ChatMessage(
|
| 186 |
+
message_id=item["message_id"],
|
| 187 |
+
role=item["role"],
|
| 188 |
+
content=item["content"],
|
| 189 |
+
figure=item["figure"] if item["figure"] else None,
|
| 190 |
+
created_date=item["created_date"],
|
| 191 |
+
)
|
| 192 |
+
for item in chat_doc["messages"]
|
| 193 |
+
]
|
| 194 |
logger.debug(f"END REPO: find_messages. Found {len(messages_list)} messages.")
|
| 195 |
return messages_list
|
| 196 |
except Exception as e:
|
|
|
|
| 199 |
|
| 200 |
logger.info(f"No messages found for completion_id {completion_id} or messages field is empty/missing.")
|
| 201 |
return []
|
| 202 |
+
|
| 203 |
async def find_plot_by_message(self, completion_id: str, message_id: str) -> Optional[dict[str, Any]]:
|
| 204 |
"""
|
| 205 |
Find a plot by a given message id.
|
| 206 |
Example : completion_id = "123", message_id = "123"
|
| 207 |
"""
|
| 208 |
logger.debug(f"BEGIN REPO: find plot by message id. input parameters: completion_id: {completion_id}, message_id: {message_id}")
|
| 209 |
+
|
| 210 |
query = {"completion_id": completion_id}
|
| 211 |
projection = {"messages": 1, "_id": 0}
|
| 212 |
try:
|
|
|
|
| 215 |
except Exception as e:
|
| 216 |
logger.error(f"Error finding plot by message id: {e}")
|
| 217 |
return None
|
| 218 |
+
|
| 219 |
# Mesajları Python tarafında filtreleyelim
|
| 220 |
if entity_doc and "messages" in entity_doc and entity_doc["messages"]:
|
| 221 |
try:
|
|
|
|
| 225 |
# figure = message.get("figure")
|
| 226 |
# logger.debug(f"REPO find figure: {figure}")
|
| 227 |
# return figure
|
| 228 |
+
|
| 229 |
match = next((message for message in entity_doc["messages"] if message["message_id"] == message_id), None)
|
| 230 |
if match:
|
| 231 |
figure = match.get("figure")
|
|
|
|
| 233 |
return figure
|
| 234 |
else:
|
| 235 |
logger.warning(f"Message with ID {message_id} not found")
|
| 236 |
+
|
| 237 |
logger.warning(f"Message with ID {message_id} not found")
|
| 238 |
return None
|
| 239 |
except Exception as e:
|
|
|
|
| 241 |
return None
|
| 242 |
else:
|
| 243 |
logger.warning(f"REPO find_plot_by_message. entity_doc: {entity_doc}")
|
| 244 |
+
return None
|
app/schema/chat_schema.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from typing import Any,
|
| 2 |
from pydantic import BaseModel, Field
|
| 3 |
from datetime import datetime
|
| 4 |
|
|
@@ -26,8 +26,6 @@ class ChatCompletionRequest(BaseModel):
|
|
| 26 |
stream: Optional[bool] = Field(None, description="Whether to stream the chat completion")
|
| 27 |
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
class MessageResponse(BaseModel):
|
| 32 |
"""
|
| 33 |
A chat completion message generated by the model.
|
|
|
|
| 1 |
+
from typing import Any, List, Optional
|
| 2 |
from pydantic import BaseModel, Field
|
| 3 |
from datetime import datetime
|
| 4 |
|
|
|
|
| 26 |
stream: Optional[bool] = Field(None, description="Whether to stream the chat completion")
|
| 27 |
|
| 28 |
|
|
|
|
|
|
|
| 29 |
class MessageResponse(BaseModel):
|
| 30 |
"""
|
| 31 |
A chat completion message generated by the model.
|
app/security/auth_service.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
from app.config.security_config import get_security_config
|
| 2 |
-
from fastapi import HTTPException, status, Security
|
| 3 |
from fastapi.security import APIKeyHeader
|
| 4 |
from loguru import logger
|
| 5 |
import base64
|
|
|
|
| 1 |
from app.config.security_config import get_security_config
|
| 2 |
+
from fastapi import HTTPException, status, Security
|
| 3 |
from fastapi.security import APIKeyHeader
|
| 4 |
from loguru import logger
|
| 5 |
import base64
|
app/service/chat_service.py
CHANGED
|
@@ -51,7 +51,13 @@ class ChatService:
|
|
| 51 |
messages = await self.chat_repository.find_messages(completion_id)
|
| 52 |
logger.debug(f"END SERVICE: find_messages for completion_id: {completion_id}, messages: {len(messages)}")
|
| 53 |
messages_response = [
|
| 54 |
-
MessageResponse(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
for message in messages
|
| 56 |
]
|
| 57 |
return messages_response
|
|
@@ -85,7 +91,7 @@ class ChatService:
|
|
| 85 |
figure = await self.chat_repository.find_plot_by_message(completion_id, message_id)
|
| 86 |
|
| 87 |
if figure:
|
| 88 |
-
result =
|
| 89 |
else:
|
| 90 |
result = None
|
| 91 |
logger.warning(f"END SERVICE: no figure found for completion_id: {completion_id}, message_id: {message_id}")
|
|
|
|
| 51 |
messages = await self.chat_repository.find_messages(completion_id)
|
| 52 |
logger.debug(f"END SERVICE: find_messages for completion_id: {completion_id}, messages: {len(messages)}")
|
| 53 |
messages_response = [
|
| 54 |
+
MessageResponse(
|
| 55 |
+
message_id=message.message_id,
|
| 56 |
+
role=message.role,
|
| 57 |
+
content=message.content,
|
| 58 |
+
created_date=message.created_date,
|
| 59 |
+
figure=(message.figure),
|
| 60 |
+
)
|
| 61 |
for message in messages
|
| 62 |
]
|
| 63 |
return messages_response
|
|
|
|
| 91 |
figure = await self.chat_repository.find_plot_by_message(completion_id, message_id)
|
| 92 |
|
| 93 |
if figure:
|
| 94 |
+
result = figure
|
| 95 |
else:
|
| 96 |
result = None
|
| 97 |
logger.warning(f"END SERVICE: no figure found for completion_id: {completion_id}, message_id: {message_id}")
|