feat: initial implement end-to-end business with mock-chat-agent-client response and validation logic.
Browse files- app/agent/chat_agent_client.py +14 -0
- app/agent/chat_agent_scheme.py +10 -0
- app/api/chat_api.py +3 -3
- app/mapper/chat_mapper.py +15 -6
- app/model/chat_model.py +4 -4
- app/repository/chat_repository.py +9 -18
- app/schema/chat_schema.py +5 -5
- app/service/chat_service.py +107 -31
- app/service/chat_validation.py +11 -0
- gradio_chatbot.py +11 -11
app/agent/chat_agent_client.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from app.agent.chat_agent_scheme import UserChatAgentRequest, AssistantChatAgentResponse
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ChatAgentClient:
|
| 5 |
+
def __init__(self):
|
| 6 |
+
self.agent_name = "ChatAgentClient"
|
| 7 |
+
|
| 8 |
+
def process(self, user_chat_agent_request: UserChatAgentRequest) -> AssistantChatAgentResponse:
|
| 9 |
+
# TODO implement the logic to process the chat
|
| 10 |
+
agent_name = self.agent_name
|
| 11 |
+
return AssistantChatAgentResponse(
|
| 12 |
+
message=f"Here is the {agent_name} Processed message: This is a placeholder response for the user-question",
|
| 13 |
+
figure=None, # Placeholder for any figure data if needed
|
| 14 |
+
)
|
app/agent/chat_agent_scheme.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class UserChatAgentRequest(BaseModel):
|
| 5 |
+
message: str
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AssistantChatAgentResponse(BaseModel):
|
| 9 |
+
message: str
|
| 10 |
+
figure: dict | None = None
|
app/api/chat_api.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
from typing import Any, List, Optional
|
| 4 |
from fastapi import APIRouter, HTTPException, Depends, Request
|
| 5 |
from pydantic import BaseModel
|
| 6 |
-
from app.schema.chat_schema import ChatCompletionRequest, ChatCompletionResponse,
|
| 7 |
from app.schema.conversation_schema import ConversationResponse, ConversationItemResponse
|
| 8 |
from app.service.chat_service import ChatService
|
| 9 |
from app.security.auth_service import AuthService
|
|
@@ -57,7 +57,7 @@ async def list_chat_completions(request: Request, username: str = Depends(auth_s
|
|
| 57 |
page: int = 1
|
| 58 |
limit: int = 10
|
| 59 |
sort: dict = {"created_date": -1}
|
| 60 |
-
project: dict =
|
| 61 |
|
| 62 |
try:
|
| 63 |
query = {"created_by": username}
|
|
@@ -80,7 +80,7 @@ async def retrieve_chat_completion(completion_id: str, request: Request, usernam
|
|
| 80 |
|
| 81 |
|
| 82 |
# get all messages for a chat completion
|
| 83 |
-
@router.get("/chat/completions/{completion_id}/messages", response_model=List[
|
| 84 |
async def list_messages(completion_id: str, request: Request, username: str = Depends(auth_service.verify_credentials)):
|
| 85 |
"""
|
| 86 |
Get all messages for a chat completion
|
|
|
|
| 3 |
from typing import Any, List, Optional
|
| 4 |
from fastapi import APIRouter, HTTPException, Depends, Request
|
| 5 |
from pydantic import BaseModel
|
| 6 |
+
from app.schema.chat_schema import ChatCompletionRequest, ChatCompletionResponse, ChatMessageResponse
|
| 7 |
from app.schema.conversation_schema import ConversationResponse, ConversationItemResponse
|
| 8 |
from app.service.chat_service import ChatService
|
| 9 |
from app.security.auth_service import AuthService
|
|
|
|
| 57 |
page: int = 1
|
| 58 |
limit: int = 10
|
| 59 |
sort: dict = {"created_date": -1}
|
| 60 |
+
project: dict = {}
|
| 61 |
|
| 62 |
try:
|
| 63 |
query = {"created_by": username}
|
|
|
|
| 80 |
|
| 81 |
|
| 82 |
# get all messages for a chat completion
|
| 83 |
+
@router.get("/chat/completions/{completion_id}/messages", response_model=List[ChatMessageResponse], deprecated=True)
|
| 84 |
async def list_messages(completion_id: str, request: Request, username: str = Depends(auth_service.verify_credentials)):
|
| 85 |
"""
|
| 86 |
Get all messages for a chat completion
|
app/mapper/chat_mapper.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
from app.mapper.base_mapper import BaseMapper
|
| 2 |
-
from app.model.chat_model import ChatCompletion,
|
| 3 |
-
from app.schema.chat_schema import ChatCompletionResponse, ChatCompletionRequest,
|
| 4 |
|
| 5 |
|
| 6 |
class ChatMapper(BaseMapper[ChatCompletion, ChatCompletionResponse]):
|
|
@@ -15,11 +15,12 @@ class ChatMapper(BaseMapper[ChatCompletion, ChatCompletionResponse]):
|
|
| 15 |
last_message = model.messages[-1] if model.messages else None
|
| 16 |
message_response = None
|
| 17 |
if last_message:
|
| 18 |
-
message_response =
|
| 19 |
message_id=last_message.message_id,
|
| 20 |
role=last_message.role,
|
| 21 |
content=last_message.content,
|
| 22 |
figure=last_message.figure,
|
|
|
|
| 23 |
)
|
| 24 |
|
| 25 |
# Create choice response
|
|
@@ -35,11 +36,19 @@ class ChatMapper(BaseMapper[ChatCompletion, ChatCompletionResponse]):
|
|
| 35 |
messages = []
|
| 36 |
if schema.messages:
|
| 37 |
for msg in schema.messages:
|
| 38 |
-
messages.append(
|
| 39 |
|
| 40 |
return ChatCompletion(
|
|
|
|
| 41 |
completion_id=schema.completion_id,
|
| 42 |
-
model=schema.model
|
| 43 |
messages=messages,
|
| 44 |
-
stream=schema.stream
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
)
|
|
|
|
| 1 |
from app.mapper.base_mapper import BaseMapper
|
| 2 |
+
from app.model.chat_model import ChatCompletion, ChatMessageModel
|
| 3 |
+
from app.schema.chat_schema import ChatCompletionResponse, ChatCompletionRequest, ChatMessageResponse, ChoiceResponse
|
| 4 |
|
| 5 |
|
| 6 |
class ChatMapper(BaseMapper[ChatCompletion, ChatCompletionResponse]):
|
|
|
|
| 15 |
last_message = model.messages[-1] if model.messages else None
|
| 16 |
message_response = None
|
| 17 |
if last_message:
|
| 18 |
+
message_response = ChatMessageResponse(
|
| 19 |
message_id=last_message.message_id,
|
| 20 |
role=last_message.role,
|
| 21 |
content=last_message.content,
|
| 22 |
figure=last_message.figure,
|
| 23 |
+
created_date=last_message.created_date,
|
| 24 |
)
|
| 25 |
|
| 26 |
# Create choice response
|
|
|
|
| 36 |
messages = []
|
| 37 |
if schema.messages:
|
| 38 |
for msg in schema.messages:
|
| 39 |
+
messages.append(ChatMessageModel(role=msg.role, content=msg.content, figure=None, message_id=None))
|
| 40 |
|
| 41 |
return ChatCompletion(
|
| 42 |
+
title=schema.title,
|
| 43 |
completion_id=schema.completion_id,
|
| 44 |
+
model=schema.model,
|
| 45 |
messages=messages,
|
| 46 |
+
stream=schema.stream,
|
| 47 |
+
created_by=schema.created_by,
|
| 48 |
+
created_date=schema.created_date,
|
| 49 |
+
object_field=schema.object_field,
|
| 50 |
+
is_archived=schema.archived,
|
| 51 |
+
is_starred=schema.starred,
|
| 52 |
+
last_updated_by=schema.last_updated_by,
|
| 53 |
+
last_updated_date=schema.last_updated_date,
|
| 54 |
)
|
app/model/chat_model.py
CHANGED
|
@@ -51,13 +51,13 @@ from typing import List, Optional, Any
|
|
| 51 |
# }
|
| 52 |
|
| 53 |
|
| 54 |
-
class
|
| 55 |
"""
|
| 56 |
A message in a chat completion.
|
| 57 |
"""
|
| 58 |
|
| 59 |
-
message_id: str = Field(
|
| 60 |
-
role:
|
| 61 |
content: str = Field(..., description="The content of the message")
|
| 62 |
figure: Optional[dict[str, Any]] = Field(None, description="The figure data for visualization")
|
| 63 |
created_date: datetime = Field(default_factory=datetime.now, description="The timestamp of the message")
|
|
@@ -89,7 +89,7 @@ class ChatCompletion(BaseModel):
|
|
| 89 |
|
| 90 |
# openai compatible fields
|
| 91 |
model: Optional[str] = Field(None, description="The model used for the chat completion", examples=["gpt-4o-mini", "gpt-4o", "gpt-3.5-turbo"])
|
| 92 |
-
messages: Optional[List[
|
| 93 |
|
| 94 |
# not implemented yet
|
| 95 |
# temperature: float = Field(default=0.7,ge=0.0, le=1.0, description="What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.")
|
|
|
|
| 51 |
# }
|
| 52 |
|
| 53 |
|
| 54 |
+
class ChatMessageModel(BaseModel):
|
| 55 |
"""
|
| 56 |
A message in a chat completion.
|
| 57 |
"""
|
| 58 |
|
| 59 |
+
message_id: Optional[str] = Field(None, description="The unique identifier for the message")
|
| 60 |
+
role: str = Field(..., description="The role of the message sender", examples=["user", "assistant", "system"])
|
| 61 |
content: str = Field(..., description="The content of the message")
|
| 62 |
figure: Optional[dict[str, Any]] = Field(None, description="The figure data for visualization")
|
| 63 |
created_date: datetime = Field(default_factory=datetime.now, description="The timestamp of the message")
|
|
|
|
| 89 |
|
| 90 |
# openai compatible fields
|
| 91 |
model: Optional[str] = Field(None, description="The model used for the chat completion", examples=["gpt-4o-mini", "gpt-4o", "gpt-3.5-turbo"])
|
| 92 |
+
messages: Optional[List[ChatMessageModel]] = Field(None, description="The messages in the chat completion")
|
| 93 |
|
| 94 |
# not implemented yet
|
| 95 |
# temperature: float = Field(default=0.7,ge=0.0, le=1.0, description="What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.")
|
app/repository/chat_repository.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
| 1 |
from typing import Any, List, Optional
|
| 2 |
from app.db.factory import db_client
|
| 3 |
-
from app.model.chat_model import
|
| 4 |
from loguru import logger
|
| 5 |
-
import uuid
|
| 6 |
import pymongo
|
| 7 |
|
| 8 |
|
|
@@ -21,7 +20,7 @@ class ChatRepository:
|
|
| 21 |
self.db = db_client.db
|
| 22 |
self.collection = "chat_completion"
|
| 23 |
|
| 24 |
-
async def
|
| 25 |
"""
|
| 26 |
Create a new chat completion in the database.
|
| 27 |
|
|
@@ -48,7 +47,7 @@ class ChatRepository:
|
|
| 48 |
logger.info(f"Successfully created new chat completion with ID: {entity.completion_id}")
|
| 49 |
return await self.find_by_id(entity.completion_id)
|
| 50 |
|
| 51 |
-
async def
|
| 52 |
"""
|
| 53 |
Update an existing chat completion in the database.
|
| 54 |
|
|
@@ -108,9 +107,9 @@ class ChatRepository:
|
|
| 108 |
try:
|
| 109 |
result = await self.find_by_id(entity.completion_id)
|
| 110 |
if result:
|
| 111 |
-
return await self.
|
| 112 |
else:
|
| 113 |
-
return await self.
|
| 114 |
except Exception as e:
|
| 115 |
logger.error(f"Error saving chat completion: {e}")
|
| 116 |
raise
|
|
@@ -118,7 +117,7 @@ class ChatRepository:
|
|
| 118 |
logger.debug("END REPO: save chat completion")
|
| 119 |
|
| 120 |
async def find(
|
| 121 |
-
self, query: dict = {}, page: int = 1, limit: int = 10, sort: dict = {"created_date": -1}, projection: dict =
|
| 122 |
) -> List[ChatCompletion]:
|
| 123 |
"""
|
| 124 |
Find a chat completion by a given query. with pagination
|
|
@@ -147,7 +146,7 @@ class ChatRepository:
|
|
| 147 |
logger.debug(f"END REPO: find, returning {len(result_models)} models.")
|
| 148 |
return result_models
|
| 149 |
|
| 150 |
-
async def find_by_id(self, completion_id: str, projection: dict = None) -> ChatCompletion:
|
| 151 |
"""
|
| 152 |
Find a chat completion by a given id.
|
| 153 |
Example : completion_id = "123"
|
|
@@ -169,7 +168,7 @@ class ChatRepository:
|
|
| 169 |
logger.info(f"Chat completion with ID {completion_id} not found in DB.")
|
| 170 |
return None
|
| 171 |
|
| 172 |
-
async def find_messages(self, completion_id: str) -> List[
|
| 173 |
"""
|
| 174 |
Find all messages for a given chat completion id.
|
| 175 |
Example : completion_id = "123"
|
|
@@ -181,7 +180,7 @@ class ChatRepository:
|
|
| 181 |
if chat_doc and "messages" in chat_doc and chat_doc["messages"]:
|
| 182 |
try:
|
| 183 |
messages_list = [
|
| 184 |
-
|
| 185 |
message_id=item["message_id"],
|
| 186 |
role=item["role"],
|
| 187 |
content=item["content"],
|
|
@@ -215,16 +214,8 @@ class ChatRepository:
|
|
| 215 |
logger.error(f"Error finding plot by message id: {e}")
|
| 216 |
return None
|
| 217 |
|
| 218 |
-
# Mesajları Python tarafında filtreleyelim
|
| 219 |
if entity_doc and "messages" in entity_doc and entity_doc["messages"]:
|
| 220 |
try:
|
| 221 |
-
# İstenen message_id'ye sahip mesajı bul
|
| 222 |
-
# for message in entity_doc["messages"]:
|
| 223 |
-
# if message["message_id"] == message_id:
|
| 224 |
-
# figure = message.get("figure")
|
| 225 |
-
# logger.debug(f"REPO find figure: {figure}")
|
| 226 |
-
# return figure
|
| 227 |
-
|
| 228 |
match = next((message for message in entity_doc["messages"] if message["message_id"] == message_id), None)
|
| 229 |
if match:
|
| 230 |
figure = match.get("figure")
|
|
|
|
| 1 |
from typing import Any, List, Optional
|
| 2 |
from app.db.factory import db_client
|
| 3 |
+
from app.model.chat_model import ChatMessageModel, ChatCompletion
|
| 4 |
from loguru import logger
|
|
|
|
| 5 |
import pymongo
|
| 6 |
|
| 7 |
|
|
|
|
| 20 |
self.db = db_client.db
|
| 21 |
self.collection = "chat_completion"
|
| 22 |
|
| 23 |
+
async def create(self, entity: ChatCompletion) -> ChatCompletion:
|
| 24 |
"""
|
| 25 |
Create a new chat completion in the database.
|
| 26 |
|
|
|
|
| 47 |
logger.info(f"Successfully created new chat completion with ID: {entity.completion_id}")
|
| 48 |
return await self.find_by_id(entity.completion_id)
|
| 49 |
|
| 50 |
+
async def update(self, entity: ChatCompletion) -> ChatCompletion:
|
| 51 |
"""
|
| 52 |
Update an existing chat completion in the database.
|
| 53 |
|
|
|
|
| 107 |
try:
|
| 108 |
result = await self.find_by_id(entity.completion_id)
|
| 109 |
if result:
|
| 110 |
+
return await self.update(entity)
|
| 111 |
else:
|
| 112 |
+
return await self.create(entity)
|
| 113 |
except Exception as e:
|
| 114 |
logger.error(f"Error saving chat completion: {e}")
|
| 115 |
raise
|
|
|
|
| 117 |
logger.debug("END REPO: save chat completion")
|
| 118 |
|
| 119 |
async def find(
|
| 120 |
+
self, query: dict = {}, page: int = 1, limit: int = 10, sort: dict = {"created_date": -1}, projection: dict = {}
|
| 121 |
) -> List[ChatCompletion]:
|
| 122 |
"""
|
| 123 |
Find a chat completion by a given query. with pagination
|
|
|
|
| 146 |
logger.debug(f"END REPO: find, returning {len(result_models)} models.")
|
| 147 |
return result_models
|
| 148 |
|
| 149 |
+
async def find_by_id(self, completion_id: str, projection: dict = None) -> ChatCompletion | None:
|
| 150 |
"""
|
| 151 |
Find a chat completion by a given id.
|
| 152 |
Example : completion_id = "123"
|
|
|
|
| 168 |
logger.info(f"Chat completion with ID {completion_id} not found in DB.")
|
| 169 |
return None
|
| 170 |
|
| 171 |
+
async def find_messages(self, completion_id: str) -> List[ChatMessageModel]:
|
| 172 |
"""
|
| 173 |
Find all messages for a given chat completion id.
|
| 174 |
Example : completion_id = "123"
|
|
|
|
| 180 |
if chat_doc and "messages" in chat_doc and chat_doc["messages"]:
|
| 181 |
try:
|
| 182 |
messages_list = [
|
| 183 |
+
ChatMessageModel(
|
| 184 |
message_id=item["message_id"],
|
| 185 |
role=item["role"],
|
| 186 |
content=item["content"],
|
|
|
|
| 214 |
logger.error(f"Error finding plot by message id: {e}")
|
| 215 |
return None
|
| 216 |
|
|
|
|
| 217 |
if entity_doc and "messages" in entity_doc and entity_doc["messages"]:
|
| 218 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
match = next((message for message in entity_doc["messages"] if message["message_id"] == message_id), None)
|
| 220 |
if match:
|
| 221 |
figure = match.get("figure")
|
app/schema/chat_schema.py
CHANGED
|
@@ -3,7 +3,7 @@ from pydantic import BaseModel, Field
|
|
| 3 |
from datetime import datetime
|
| 4 |
|
| 5 |
|
| 6 |
-
class
|
| 7 |
"""
|
| 8 |
Represents a message in a chat completion.
|
| 9 |
"""
|
|
@@ -22,11 +22,11 @@ class ChatCompletionRequest(BaseModel):
|
|
| 22 |
description="The unique identifier for the chat completion. When starting a new chat, this will be a new UUID. When continuing a previous chat, this will be the same as the previous chat completion id.",
|
| 23 |
)
|
| 24 |
model: Optional[str] = Field(None, description="The model to use for the chat completion")
|
| 25 |
-
messages: Optional[List[
|
| 26 |
stream: Optional[bool] = Field(None, description="Whether to stream the chat completion")
|
| 27 |
|
| 28 |
|
| 29 |
-
class
|
| 30 |
"""
|
| 31 |
A chat completion message generated by the model.
|
| 32 |
"""
|
|
@@ -45,7 +45,7 @@ class ChoiceResponse(BaseModel):
|
|
| 45 |
examples=["stop", "length", "content_filter"],
|
| 46 |
)
|
| 47 |
index: Optional[int] = Field(None, description="The index of the choice in the list of choices.")
|
| 48 |
-
message: Optional[
|
| 49 |
# logprobs: str = None # not implemented yet
|
| 50 |
|
| 51 |
|
|
@@ -54,7 +54,7 @@ class ChatCompletionResponse(BaseModel):
|
|
| 54 |
Represents a chat completion response returned by model, based on the provided input.
|
| 55 |
"""
|
| 56 |
|
| 57 |
-
completion_id:
|
| 58 |
choices: Optional[List[ChoiceResponse]] = Field(None, description="A list of chat completion choices.")
|
| 59 |
created: Optional[int] = Field(None, description="The Unix timestamp (in seconds) of when the chat completion was created.")
|
| 60 |
model: Optional[str] = Field(None, description="The model used for the chat completion")
|
|
|
|
| 3 |
from datetime import datetime
|
| 4 |
|
| 5 |
|
| 6 |
+
class ChatMessageRequest(BaseModel):
|
| 7 |
"""
|
| 8 |
Represents a message in a chat completion.
|
| 9 |
"""
|
|
|
|
| 22 |
description="The unique identifier for the chat completion. When starting a new chat, this will be a new UUID. When continuing a previous chat, this will be the same as the previous chat completion id.",
|
| 23 |
)
|
| 24 |
model: Optional[str] = Field(None, description="The model to use for the chat completion")
|
| 25 |
+
messages: Optional[List[ChatMessageRequest]] = Field(None, description="The messages to use for the chat completion")
|
| 26 |
stream: Optional[bool] = Field(None, description="Whether to stream the chat completion")
|
| 27 |
|
| 28 |
|
| 29 |
+
class ChatMessageResponse(BaseModel):
|
| 30 |
"""
|
| 31 |
A chat completion message generated by the model.
|
| 32 |
"""
|
|
|
|
| 45 |
examples=["stop", "length", "content_filter"],
|
| 46 |
)
|
| 47 |
index: Optional[int] = Field(None, description="The index of the choice in the list of choices.")
|
| 48 |
+
message: Optional[ChatMessageResponse] = Field(None, description="The message to use for the chat completion")
|
| 49 |
# logprobs: str = None # not implemented yet
|
| 50 |
|
| 51 |
|
|
|
|
| 54 |
Represents a chat completion response returned by model, based on the provided input.
|
| 55 |
"""
|
| 56 |
|
| 57 |
+
completion_id: str = Field(None, description="The unique identifier for the chat completion")
|
| 58 |
choices: Optional[List[ChoiceResponse]] = Field(None, description="A list of chat completion choices.")
|
| 59 |
created: Optional[int] = Field(None, description="The Unix timestamp (in seconds) of when the chat completion was created.")
|
| 60 |
model: Optional[str] = Field(None, description="The model used for the chat completion")
|
app/service/chat_service.py
CHANGED
|
@@ -1,12 +1,17 @@
|
|
| 1 |
import datetime
|
| 2 |
-
from typing import Any, List
|
|
|
|
|
|
|
|
|
|
| 3 |
from app.repository.chat_repository import ChatRepository
|
| 4 |
-
from app.schema.chat_schema import ChatCompletionRequest, ChatCompletionResponse,
|
| 5 |
from app.mapper.chat_mapper import ChatMapper
|
| 6 |
from app.mapper.conversation_mapper import ConversationMapper
|
| 7 |
import uuid
|
| 8 |
from loguru import logger
|
| 9 |
from app.schema.conversation_schema import ConversationResponse
|
|
|
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
class ChatService:
|
|
@@ -14,6 +19,8 @@ class ChatService:
|
|
| 14 |
self.chat_repository = ChatRepository()
|
| 15 |
self.chat_mapper = ChatMapper()
|
| 16 |
self.conversation_mapper = ConversationMapper()
|
|
|
|
|
|
|
| 17 |
|
| 18 |
async def find(self, query: dict, page: int, limit: int, sort: dict, project: dict = None) -> List[ChatCompletionResponse]:
|
| 19 |
logger.debug(f"BEGIN SERVICE: find for query: {query}, page: {page}, limit: {limit}, sort: {sort}, project: {project}")
|
|
@@ -24,12 +31,12 @@ class ChatService:
|
|
| 24 |
entity = await self.chat_repository.find_by_id(completion_id, project)
|
| 25 |
return self.chat_mapper.to_schema(entity) if entity else None
|
| 26 |
|
| 27 |
-
async def find_messages(self, completion_id: str) -> List[
|
| 28 |
logger.debug(f"BEGIN SERVICE: find_messages for completion_id: {completion_id}")
|
| 29 |
messages = await self.chat_repository.find_messages(completion_id)
|
| 30 |
logger.debug(f"END SERVICE: find_messages for completion_id: {completion_id}, messages: {len(messages)}")
|
| 31 |
messages_response = [
|
| 32 |
-
|
| 33 |
message_id=message.message_id,
|
| 34 |
role=message.role,
|
| 35 |
content=message.content,
|
|
@@ -41,7 +48,7 @@ class ChatService:
|
|
| 41 |
return messages_response
|
| 42 |
|
| 43 |
# conversation service
|
| 44 |
-
async def find_all_conversations(self, username: str) ->
|
| 45 |
"""Find all conversations for a given username."""
|
| 46 |
query = {"created_by": username}
|
| 47 |
sort = {"last_updated_date": -1} # Sort by last updated date in descending order
|
|
@@ -50,7 +57,8 @@ class ChatService:
|
|
| 50 |
result = self.conversation_mapper.to_schema_list(entities)
|
| 51 |
return ConversationResponse(items=result, total=len(result), limit=100, offset=0)
|
| 52 |
|
| 53 |
-
|
|
|
|
| 54 |
"""Find a conversation by its completion ID."""
|
| 55 |
logger.debug(f"BEGIN SERVICE: find_conversation_by_id for completion_id: {completion_id}")
|
| 56 |
projection = {"messages": 0, "_id": 0}
|
|
@@ -58,13 +66,13 @@ class ChatService:
|
|
| 58 |
logger.debug(f"END SERVICE: find_conversation_by_id for completion_id: {completion_id}, entity: {entity}")
|
| 59 |
|
| 60 |
if entity:
|
| 61 |
-
|
| 62 |
-
result =
|
| 63 |
return result
|
| 64 |
else:
|
| 65 |
return None
|
| 66 |
|
| 67 |
-
async def find_plot_by_message(self, completion_id: str, message_id: str) ->
|
| 68 |
logger.debug(f"BEGIN SERVICE: find_plot_by_message for completion_id: {completion_id}, message_id: {message_id}")
|
| 69 |
figure = await self.chat_repository.find_plot_by_message(completion_id, message_id)
|
| 70 |
|
|
@@ -77,26 +85,94 @@ class ChatService:
|
|
| 77 |
logger.debug(f"END SERVICE: find_plot_by_message for completion_id: {completion_id}, message_id: {message_id} with figure")
|
| 78 |
return result
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import datetime
|
| 2 |
+
from typing import Any, List
|
| 3 |
+
|
| 4 |
+
from app.agent.chat_agent_scheme import UserChatAgentRequest
|
| 5 |
+
from app.model.chat_model import ChatMessageModel
|
| 6 |
from app.repository.chat_repository import ChatRepository
|
| 7 |
+
from app.schema.chat_schema import ChatCompletionRequest, ChatCompletionResponse, ChatMessageResponse, ChatMessageRequest
|
| 8 |
from app.mapper.chat_mapper import ChatMapper
|
| 9 |
from app.mapper.conversation_mapper import ConversationMapper
|
| 10 |
import uuid
|
| 11 |
from loguru import logger
|
| 12 |
from app.schema.conversation_schema import ConversationResponse
|
| 13 |
+
from app.service.chat_validation import ChatValidation
|
| 14 |
+
from app.agent.chat_agent_client import ChatAgentClient
|
| 15 |
|
| 16 |
|
| 17 |
class ChatService:
|
|
|
|
| 19 |
self.chat_repository = ChatRepository()
|
| 20 |
self.chat_mapper = ChatMapper()
|
| 21 |
self.conversation_mapper = ConversationMapper()
|
| 22 |
+
self.chat_validation = ChatValidation()
|
| 23 |
+
self.chat_agent_client = ChatAgentClient()
|
| 24 |
|
| 25 |
async def find(self, query: dict, page: int, limit: int, sort: dict, project: dict = None) -> List[ChatCompletionResponse]:
|
| 26 |
logger.debug(f"BEGIN SERVICE: find for query: {query}, page: {page}, limit: {limit}, sort: {sort}, project: {project}")
|
|
|
|
| 31 |
entity = await self.chat_repository.find_by_id(completion_id, project)
|
| 32 |
return self.chat_mapper.to_schema(entity) if entity else None
|
| 33 |
|
| 34 |
+
async def find_messages(self, completion_id: str) -> List[ChatMessageResponse]:
|
| 35 |
logger.debug(f"BEGIN SERVICE: find_messages for completion_id: {completion_id}")
|
| 36 |
messages = await self.chat_repository.find_messages(completion_id)
|
| 37 |
logger.debug(f"END SERVICE: find_messages for completion_id: {completion_id}, messages: {len(messages)}")
|
| 38 |
messages_response = [
|
| 39 |
+
ChatMessageResponse(
|
| 40 |
message_id=message.message_id,
|
| 41 |
role=message.role,
|
| 42 |
content=message.content,
|
|
|
|
| 48 |
return messages_response
|
| 49 |
|
| 50 |
# conversation service
|
| 51 |
+
async def find_all_conversations(self, username: str) -> ConversationResponse:
|
| 52 |
"""Find all conversations for a given username."""
|
| 53 |
query = {"created_by": username}
|
| 54 |
sort = {"last_updated_date": -1} # Sort by last updated date in descending order
|
|
|
|
| 57 |
result = self.conversation_mapper.to_schema_list(entities)
|
| 58 |
return ConversationResponse(items=result, total=len(result), limit=100, offset=0)
|
| 59 |
|
| 60 |
+
# conversation service
|
| 61 |
+
async def find_conversation_by_id(self, completion_id: str) -> ConversationResponse | None:
|
| 62 |
"""Find a conversation by its completion ID."""
|
| 63 |
logger.debug(f"BEGIN SERVICE: find_conversation_by_id for completion_id: {completion_id}")
|
| 64 |
projection = {"messages": 0, "_id": 0}
|
|
|
|
| 66 |
logger.debug(f"END SERVICE: find_conversation_by_id for completion_id: {completion_id}, entity: {entity}")
|
| 67 |
|
| 68 |
if entity:
|
| 69 |
+
conversation_item = self.conversation_mapper.to_schema(entity)
|
| 70 |
+
result = ConversationResponse(items=[conversation_item], total=1, limit=1, offset=0)
|
| 71 |
return result
|
| 72 |
else:
|
| 73 |
return None
|
| 74 |
|
| 75 |
+
async def find_plot_by_message(self, completion_id: str, message_id: str) -> dict[str, Any]:
|
| 76 |
logger.debug(f"BEGIN SERVICE: find_plot_by_message for completion_id: {completion_id}, message_id: {message_id}")
|
| 77 |
figure = await self.chat_repository.find_plot_by_message(completion_id, message_id)
|
| 78 |
|
|
|
|
| 85 |
logger.debug(f"END SERVICE: find_plot_by_message for completion_id: {completion_id}, message_id: {message_id} with figure")
|
| 86 |
return result
|
| 87 |
|
| 88 |
+
async def _save_chat_completion(self, request: ChatCompletionRequest, username: str) -> ChatCompletionResponse:
|
| 89 |
+
"""
|
| 90 |
+
Save a chat completion to the database.
|
| 91 |
+
"""
|
| 92 |
+
logger.debug(f"BEGIN SERVICE: for request: {request}, username: {username}")
|
| 93 |
+
try:
|
| 94 |
+
# Convert request to model
|
| 95 |
+
entity = self.chat_mapper.to_model(request)
|
| 96 |
+
|
| 97 |
+
entity.last_updated_by = username
|
| 98 |
+
entity.last_updated_date = datetime.datetime.now()
|
| 99 |
+
if entity.completion_id:
|
| 100 |
+
# generate a new chat completion
|
| 101 |
+
entity.completion_id = str(uuid.uuid4())
|
| 102 |
+
last_user_request_message = request.messages[-1]
|
| 103 |
+
current_entity = await self.chat_repository.find_by_id(entity.completion_id)
|
| 104 |
+
if not current_entity:
|
| 105 |
+
# create new chat completion with new user request message
|
| 106 |
+
entity.created_by = username
|
| 107 |
+
entity.created_date = datetime.datetime.now()
|
| 108 |
+
entity.last_updated_by = username
|
| 109 |
+
entity.last_updated_date = datetime.datetime.now()
|
| 110 |
+
# title can generate with LLM from user request message.content
|
| 111 |
+
entity.title = last_user_request_message.content[:20]
|
| 112 |
+
final_entity = await self.chat_repository.create(entity)
|
| 113 |
+
else:
|
| 114 |
+
# update existing chat completion with new user request message
|
| 115 |
+
|
| 116 |
+
message_model = ChatMessageModel(
|
| 117 |
+
message_id=str(uuid.uuid4()),
|
| 118 |
+
role=last_user_request_message.role,
|
| 119 |
+
content=last_user_request_message.content,
|
| 120 |
+
figure=None,
|
| 121 |
+
created_date=datetime.datetime.now(),
|
| 122 |
+
)
|
| 123 |
+
current_entity.messages.append(message_model)
|
| 124 |
+
current_entity.last_updated_date = datetime.datetime.now()
|
| 125 |
+
final_entity = await self.chat_repository.update(current_entity)
|
| 126 |
+
|
| 127 |
+
# Convert model to response
|
| 128 |
+
result = self.chat_mapper.to_schema(final_entity)
|
| 129 |
+
logger.debug("END SERVICE")
|
| 130 |
+
return result
|
| 131 |
+
except Exception as e:
|
| 132 |
+
logger.error(f"Error saving chat completion: {e}")
|
| 133 |
+
raise
|
| 134 |
+
|
| 135 |
+
async def chat_agent_client_process(self, user_chat_completion: ChatCompletionRequest, username: str):
|
| 136 |
+
logger.debug(f"BEGIN SERVICE: Agentic Chat AI process. username: {username}")
|
| 137 |
+
last_user_message = user_chat_completion.messages[-1].content
|
| 138 |
+
user_chat_agent_request = UserChatAgentRequest(message=last_user_message)
|
| 139 |
+
result = self.chat_agent_client.process(user_chat_agent_request)
|
| 140 |
+
logger.debug("END SERVICE: Agentic Chat AI process")
|
| 141 |
+
return result
|
| 142 |
|
| 143 |
+
async def handle_chat_completion(self, user_chat_completion: ChatCompletionRequest, username: str) -> ChatCompletionResponse:
|
| 144 |
+
last_user_message = user_chat_completion
|
| 145 |
+
logger.debug(f"BEGIN SERVICE: last_user_message: {last_user_message}, username: {username}")
|
| 146 |
+
|
| 147 |
+
# validate user message
|
| 148 |
+
self.chat_validation.validate_request(user_chat_completion)
|
| 149 |
+
|
| 150 |
+
# save user message to database
|
| 151 |
+
logger.info("Saving user message to database")
|
| 152 |
+
repo_user_message = await self._save_chat_completion(user_chat_completion, username)
|
| 153 |
+
logger.info(f"Saved user message to database with completion_id: {repo_user_message.completion_id}")
|
| 154 |
+
|
| 155 |
+
# region agentic-ai process start #########################################################
|
| 156 |
+
try:
|
| 157 |
+
logger.info("Agentic Chat AI process started")
|
| 158 |
+
agent_result = await self.chat_agent_client_process(user_chat_completion, username)
|
| 159 |
+
assistant_message = ChatMessageRequest(role="assistant", content=agent_result.message)
|
| 160 |
+
assistant_chat_completion = user_chat_completion
|
| 161 |
+
assistant_chat_completion.messages = [assistant_message] # replace user messages with assistant message
|
| 162 |
+
logger.info(f"Agentic Chat AI process completed. Part of Assistant Message...: {assistant_message.content[:50]}...")
|
| 163 |
+
except Exception as e:
|
| 164 |
+
logger.error(f"Error agentic-ai process: {e}")
|
| 165 |
+
raise
|
| 166 |
+
# endregion agentic-ai process start ######################################################
|
| 167 |
+
|
| 168 |
+
# validate agent response
|
| 169 |
+
self.chat_validation.validate_response(agent_result)
|
| 170 |
+
|
| 171 |
+
# save assistant message to database
|
| 172 |
+
repo_assistant_message = await self._save_chat_completion(assistant_chat_completion, username)
|
| 173 |
+
|
| 174 |
+
# generate api response with user, agent, db etc... TBD
|
| 175 |
+
result = repo_assistant_message
|
| 176 |
+
|
| 177 |
+
logger.debug("END SERVICE")
|
| 178 |
+
return result
|
app/service/chat_validation.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class ChatValidation:
|
| 2 |
+
def __init__(self):
|
| 3 |
+
pass
|
| 4 |
+
|
| 5 |
+
def validate_request(self, completion):
|
| 6 |
+
# TODO implement request validation logic
|
| 7 |
+
pass
|
| 8 |
+
|
| 9 |
+
def validate_response(self, agent_result):
|
| 10 |
+
# TODO implement response validation logic
|
| 11 |
+
pass
|
gradio_chatbot.py
CHANGED
|
@@ -2,7 +2,7 @@ import json
|
|
| 2 |
import gradio as gr
|
| 3 |
import environs
|
| 4 |
import httpx
|
| 5 |
-
from typing import List, Tuple, Optional
|
| 6 |
from dataclasses import dataclass
|
| 7 |
from enum import Enum
|
| 8 |
import os
|
|
@@ -125,7 +125,7 @@ class MessageStatus(Enum):
|
|
| 125 |
|
| 126 |
|
| 127 |
@dataclass
|
| 128 |
-
class
|
| 129 |
"""Data class for message response"""
|
| 130 |
|
| 131 |
status: MessageStatus
|
|
@@ -142,7 +142,7 @@ class ChatAPI:
|
|
| 142 |
self.api_key = api_key
|
| 143 |
self.endpoint = f"{base_url}/v1/chat/completions"
|
| 144 |
|
| 145 |
-
async def send_message(self, prompt: str) ->
|
| 146 |
"""
|
| 147 |
Send a message to the chat API
|
| 148 |
|
|
@@ -150,7 +150,7 @@ class ChatAPI:
|
|
| 150 |
prompt (str): The message to send
|
| 151 |
|
| 152 |
Returns:
|
| 153 |
-
|
| 154 |
"""
|
| 155 |
logger.trace(f"Calling chat API with prompt: {prompt}")
|
| 156 |
try:
|
|
@@ -169,7 +169,7 @@ class ChatAPI:
|
|
| 169 |
|
| 170 |
if response.status_code != 200:
|
| 171 |
logger.error(f"API Error: {response.text}")
|
| 172 |
-
return
|
| 173 |
status=MessageStatus.ERROR,
|
| 174 |
content="",
|
| 175 |
figure=None,
|
|
@@ -187,14 +187,14 @@ class ChatAPI:
|
|
| 187 |
logger.trace(f"Figure: {figure}")
|
| 188 |
content = message.get("content", "Content not found")
|
| 189 |
logger.trace(f"Last message: {content}")
|
| 190 |
-
return
|
| 191 |
status=MessageStatus.SUCCESS,
|
| 192 |
content=content,
|
| 193 |
figure=figure,
|
| 194 |
)
|
| 195 |
else:
|
| 196 |
logger.error("Invalid API response")
|
| 197 |
-
return
|
| 198 |
status=MessageStatus.ERROR,
|
| 199 |
content="",
|
| 200 |
error="Invalid API response",
|
|
@@ -202,14 +202,14 @@ class ChatAPI:
|
|
| 202 |
|
| 203 |
except httpx.TimeoutException:
|
| 204 |
logger.error("API request timed out")
|
| 205 |
-
return
|
| 206 |
status=MessageStatus.ERROR,
|
| 207 |
content="",
|
| 208 |
error="Request timed out. Please try again.",
|
| 209 |
)
|
| 210 |
except Exception as e:
|
| 211 |
logger.error(f"Error: {str(e)}")
|
| 212 |
-
return
|
| 213 |
status=MessageStatus.ERROR,
|
| 214 |
content="",
|
| 215 |
error=f"Error: {str(e)}",
|
|
@@ -319,13 +319,13 @@ class ChatInterface:
|
|
| 319 |
None,
|
| 320 |
)
|
| 321 |
|
| 322 |
-
def clear_history() ->
|
| 323 |
"""Clear chat history"""
|
| 324 |
return [], "", "Chat cleared.", "", None
|
| 325 |
|
| 326 |
def retry_last_message(
|
| 327 |
history: List[List[str]],
|
| 328 |
-
) ->
|
| 329 |
"""Retry the last message"""
|
| 330 |
if not history:
|
| 331 |
return history, "", "No message to retry.", "", None
|
|
|
|
| 2 |
import gradio as gr
|
| 3 |
import environs
|
| 4 |
import httpx
|
| 5 |
+
from typing import List, Tuple, Optional, Any
|
| 6 |
from dataclasses import dataclass
|
| 7 |
from enum import Enum
|
| 8 |
import os
|
|
|
|
| 125 |
|
| 126 |
|
| 127 |
@dataclass
|
| 128 |
+
class ChatMessageResponse:
|
| 129 |
"""Data class for message response"""
|
| 130 |
|
| 131 |
status: MessageStatus
|
|
|
|
| 142 |
self.api_key = api_key
|
| 143 |
self.endpoint = f"{base_url}/v1/chat/completions"
|
| 144 |
|
| 145 |
+
async def send_message(self, prompt: str) -> ChatMessageResponse:
|
| 146 |
"""
|
| 147 |
Send a message to the chat API
|
| 148 |
|
|
|
|
| 150 |
prompt (str): The message to send
|
| 151 |
|
| 152 |
Returns:
|
| 153 |
+
ChatMessageResponse: The response from the API
|
| 154 |
"""
|
| 155 |
logger.trace(f"Calling chat API with prompt: {prompt}")
|
| 156 |
try:
|
|
|
|
| 169 |
|
| 170 |
if response.status_code != 200:
|
| 171 |
logger.error(f"API Error: {response.text}")
|
| 172 |
+
return ChatMessageResponse(
|
| 173 |
status=MessageStatus.ERROR,
|
| 174 |
content="",
|
| 175 |
figure=None,
|
|
|
|
| 187 |
logger.trace(f"Figure: {figure}")
|
| 188 |
content = message.get("content", "Content not found")
|
| 189 |
logger.trace(f"Last message: {content}")
|
| 190 |
+
return ChatMessageResponse(
|
| 191 |
status=MessageStatus.SUCCESS,
|
| 192 |
content=content,
|
| 193 |
figure=figure,
|
| 194 |
)
|
| 195 |
else:
|
| 196 |
logger.error("Invalid API response")
|
| 197 |
+
return ChatMessageResponse(
|
| 198 |
status=MessageStatus.ERROR,
|
| 199 |
content="",
|
| 200 |
error="Invalid API response",
|
|
|
|
| 202 |
|
| 203 |
except httpx.TimeoutException:
|
| 204 |
logger.error("API request timed out")
|
| 205 |
+
return ChatMessageResponse(
|
| 206 |
status=MessageStatus.ERROR,
|
| 207 |
content="",
|
| 208 |
error="Request timed out. Please try again.",
|
| 209 |
)
|
| 210 |
except Exception as e:
|
| 211 |
logger.error(f"Error: {str(e)}")
|
| 212 |
+
return ChatMessageResponse(
|
| 213 |
status=MessageStatus.ERROR,
|
| 214 |
content="",
|
| 215 |
error=f"Error: {str(e)}",
|
|
|
|
| 319 |
None,
|
| 320 |
)
|
| 321 |
|
| 322 |
+
def clear_history() -> tuple[list[Any], str, str, str, None]:
|
| 323 |
"""Clear chat history"""
|
| 324 |
return [], "", "Chat cleared.", "", None
|
| 325 |
|
| 326 |
def retry_last_message(
|
| 327 |
history: List[List[str]],
|
| 328 |
+
) -> tuple[list[list[str]], str, str, str, None]:
|
| 329 |
"""Retry the last message"""
|
| 330 |
if not history:
|
| 331 |
return history, "", "No message to retry.", "", None
|