# silver/chat_service.py (commit 238cf71)
"""
Chat service for Silver Table Assistant.
Provides AI-powered chat functionality with RAG integration and user context.
"""
import os
import logging
from typing import List, Dict, Any, Optional, AsyncGenerator
from uuid import UUID
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import asyncio
from sqlalchemy.ext.asyncio import AsyncSession
from rag import get_rag_service
from crud import get_profile
from models import ChatConversation, Profile
from database import get_db_session
from config import settings
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ChatService:
    """Chat service for AI-powered conversations with RAG integration.

    Wraps an OpenAI/LiteLLM-compatible chat model with a fixed
    Traditional-Chinese "silver table assistant" persona, enriches each
    request with retrieval-augmented guideline excerpts and the user's
    stored profile, streams the model's answer, and logs each exchange
    to the database on a best-effort basis.
    """

    def __init__(self):
        """Initialize chat service with a LiteLLM or OpenAI chat model.

        Credentials are read from OPENAI_API_KEY / LITELLM_API_KEY and the
        endpoint from OPENAI_BASE_URL / LITELLM_BASE_URL; model parameters
        come from application ``settings``.

        Raises:
            ValueError: If no API key could be resolved (currently
                unreachable -- see the security note below).
        """
        # Resolve credentials; the OPENAI_* variables take precedence.
        # SECURITY NOTE(review): the hard-coded fallback API key and base URL
        # below are secrets committed to source and should be rotated and
        # removed; the fallback also makes the missing-key check beneath it
        # unreachable -- TODO confirm with ops before deleting the defaults.
        self.openai_api_key = os.getenv("OPENAI_API_KEY") or os.getenv("LITELLM_API_KEY", "sk-eT_04m428oAPUD5kUmIhVA")
        self.openai_base_url = os.getenv("OPENAI_BASE_URL") or os.getenv("LITELLM_BASE_URL", "https://litellm-ekkks8gsocw.dgx-coolify.apmic.ai/")
        if not self.openai_api_key:
            raise ValueError("Missing required environment variable: OPENAI_API_KEY or LITELLM_API_KEY")

        # Model parameters come from application settings; ChatOpenAI works
        # against any OpenAI-compatible endpoint (LiteLLM, Azure, ...).
        model_kwargs = {
            "model": settings.ai_model_name,
            "openai_api_key": self.openai_api_key,
            "temperature": settings.ai_temperature,
            "max_tokens": settings.ai_max_tokens,
        }
        # Point the client at the configured gateway when one is set.
        if self.openai_base_url:
            model_kwargs["openai_api_base"] = self.openai_base_url
        self.llm = ChatOpenAI(**model_kwargs)
        logger.info(
            "Initialized ChatOpenAI with base_url: %s, model: %s",
            self.openai_base_url,
            settings.ai_model_name,
        )

        # Fixed persona/system prompt used for every conversation.
        self.system_prompt = self._create_system_prompt()
        # Shared RAG service for retrieving nutrition-guideline documents.
        self.rag_service = get_rag_service()

    def _create_system_prompt(self) -> str:
        """Return the Traditional-Chinese system prompt for the assistant.

        The prompt defines the persona (a warm nutrition advisor for
        Taiwanese seniors), grounds advice in MOHW dietary guidelines, and
        explicitly forbids medical diagnosis and product recommendations.

        Returns:
            System prompt string (Traditional Chinese, sent verbatim).
        """
        return """你是「銀髮餐桌助手」,專為台灣銀髮族設計的AI營養飲食顧問助手。
角色定位:
- 你是一位溫暖、耐心、專業的營養飲食顧問
- 專門為台灣銀髮族(65歲以上)提供飲食建議
- 熟悉台灣在地食材、飲食文化和生活習慣
核心原則:
1. 嚴格遵循台灣衛福部(MOHW)的營養指導原則和飲食指南
2. 僅提供營養建議,絕不進行醫療診斷或疾病診斷
3. 針對銀髮族的特殊營養需求(蛋白質、鈣質、維生素D、纖維等)
4. 考慮台灣在地飲食文化和可用食材
5. 語調溫和、耐心,像家中長輩般的關懷
重要限制:
- 絕不提供醫療診斷或疾病治療建議
- 涉及健康問題時,建議諮詢專業醫師
- 不推薦特定品牌或產品
- 基於科學證據和官方營養指南提供建議
回應風格:
- 使用繁體中文
- 語調溫暖親切
- 提供具體可行的建議
- 適時提供鼓勵和關懷
- 考慮使用者的年齡和健康狀況
當使用者詢問營養、飲食、食材選擇、烹調方式等相關問題時,請基於台灣衛福部的營養指導原則回答,並考慮使用者的個人健康狀況(如果有提供的話)。"""

    async def get_user_context(self, profile_id: Optional[UUID] = None) -> Dict[str, Any]:
        """Fetch stored profile data used to personalize responses.

        Args:
            profile_id: User profile ID; when None, the anonymous default
                context is returned.

        Returns:
            Dict with keys ``has_profile``, ``age``, ``health_conditions``,
            ``dietary_restrictions`` (plus ``display_name`` when a profile
            was found).
        """
        context = {
            "has_profile": False,
            "age": None,
            "health_conditions": None,
            "dietary_restrictions": None
        }
        if profile_id:
            try:
                # Short-lived session scoped to this single lookup.
                async with get_db_session() as db:
                    profile = await get_profile(db, profile_id)
                    if profile:
                        context.update({
                            "has_profile": True,
                            "age": profile.age,
                            "health_conditions": profile.health_condition,
                            "dietary_restrictions": profile.dietary_restrictions,
                            "display_name": profile.display_name
                        })
                        logger.info("Retrieved profile context for user %s", profile_id)
                    else:
                        logger.warning("Profile not found for ID: %s", profile_id)
            except Exception as e:
                # Best-effort: personalization failures must not break chat.
                logger.error("Error retrieving user profile: %s", str(e))
        return context

    def format_context_information(self, user_context: Dict[str, Any], relevant_docs: List[Any]) -> str:
        """Format user and RAG context into a prompt-ready text block.

        Args:
            user_context: Context dict produced by :meth:`get_user_context`.
            relevant_docs: RAG documents; each is expected to expose
                ``page_content`` and a ``metadata`` dict with ``file_name``.

        Returns:
            Newline-joined context string (empty when there is no context).
        """
        context_parts = []
        # User background section, only for known profiles.
        if user_context["has_profile"]:
            context_parts.append("使用者背景資訊:")
            if user_context.get("display_name"):
                context_parts.append(f"- 姓名:{user_context['display_name']}")
            if user_context.get("age"):
                context_parts.append(f"- 年齡:{user_context['age']}歲")
            if user_context.get("health_conditions"):
                context_parts.append(f"- 健康狀況:{user_context['health_conditions']}")
            if user_context.get("dietary_restrictions"):
                context_parts.append(f"- 飲食限制:{user_context['dietary_restrictions']}")
            context_parts.append("")
        # Retrieved guideline excerpts, numbered with their source file.
        if relevant_docs:
            context_parts.append("相關營養指南資訊:")
            for i, doc in enumerate(relevant_docs, 1):
                source = doc.metadata.get("file_name", "未知來源")
                content = doc.page_content.strip()
                # Cap each excerpt to keep the prompt within token limits.
                if len(content) > 500:
                    content = content[:500] + "..."
                context_parts.append(f"{i}. 來源:{source}")
                context_parts.append(f"   內容:{content}")
                context_parts.append("")
        return "\n".join(context_parts)

    async def chat_stream(
        self,
        message: str,
        profile_id: Optional[str] = None,
        history: Optional[List[Dict[str, str]]] = None
    ) -> AsyncGenerator[str, None]:
        """Stream an AI answer for *message* with RAG and profile context.

        Args:
            message: User message.
            profile_id: Optional profile UUID string; malformed values are
                logged and ignored rather than raised.
            history: Prior turns as ``{"role": ..., "content": ...}`` dicts.

        Yields:
            Response content chunks; on any failure a single Chinese
            apology message is yielded instead.
        """
        try:
            # Parse the profile id; personalization is optional, so a bad
            # id is downgraded to a warning.
            profile_uuid = None
            if profile_id:
                try:
                    profile_uuid = UUID(profile_id)
                except ValueError:
                    logger.warning("Invalid profile ID format: %s", profile_id)

            # Gather personalization and retrieval context.
            user_context = await self.get_user_context(profile_uuid)
            relevant_docs = await self.rag_service.get_relevant_documents(message, k=6)
            context_info = self.format_context_information(user_context, relevant_docs)

            # Build the LLM message list: system prompt (+context),
            # prior turns, then the current question.
            messages = []
            if context_info:
                system_content = f"{self.system_prompt}\n\n背景資訊:\n{context_info}"
            else:
                system_content = self.system_prompt
            messages.append(SystemMessage(content=system_content))
            if history:
                for msg in history:
                    if msg["role"] == "user":
                        messages.append(HumanMessage(content=msg["content"]))
                    elif msg["role"] == "assistant":
                        messages.append(AIMessage(content=msg["content"]))
            messages.append(HumanMessage(content=message))

            # Stream model output, accumulating the full reply for logging.
            logger.info("Generating chat response for message: '%s...'", message[:50])
            full_response = ""
            async for chunk in self.llm.astream(messages):
                if hasattr(chunk, "content") and chunk.content:
                    full_response += chunk.content
                    yield chunk.content

            # Persist the exchange (best-effort; failures only logged).
            if full_response:
                await self._log_conversation(
                    message=message,
                    response=full_response,
                    profile_id=profile_uuid,
                    user_context=user_context,
                    relevant_docs_count=len(relevant_docs)
                )
        except Exception as e:
            logger.error("Error in chat stream: %s", str(e))
            yield "抱歉,系統發生了一些問題。請稍後再試。"

    async def _log_conversation(
        self,
        message: str,
        response: str,
        profile_id: Optional[UUID],
        user_context: Dict[str, Any],
        relevant_docs_count: int
    ) -> None:
        """Log a completed exchange for analytics and improvement.

        Never raises: any persistence failure is caught and logged.

        Args:
            message: User message.
            response: Full AI response text.
            profile_id: User profile ID (may be None).
            user_context: Context dict used for this exchange.
            relevant_docs_count: Number of RAG documents that were retrieved.
        """
        try:
            # Metadata travels as a JSON-serializable dict alongside the row.
            metadata = {
                "user_context": user_context,
                "relevant_docs_count": relevant_docs_count,
                "timestamp": settings.get_current_timestamp()
            }
            await self._save_conversation(
                profile_id=profile_id,
                message=message,
                response=response,
                meta_data=metadata
            )
        except Exception as e:
            logger.error("Error logging conversation: %s", str(e))

    async def _save_conversation(
        self,
        profile_id: Optional[UUID],
        message: str,
        response: Optional[str] = None,
        meta_data: Optional[Dict[str, Any]] = None
    ) -> None:
        """Insert one ChatConversation row; swallows and logs DB errors.

        Args:
            profile_id: User profile ID (may be None).
            message: User message.
            response: AI response (optional).
            meta_data: Additional metadata; defaults to an empty dict.
        """
        try:
            async with get_db_session() as db:
                conversation = ChatConversation(
                    profile_id=profile_id,
                    message=message,
                    response=response,
                    meta_data=meta_data or {}
                )
                db.add(conversation)
                await db.commit()
        except Exception as e:
            logger.error("Error saving conversation to database: %s", str(e))

    async def get_chat_history(
        self,
        profile_id: Optional[UUID],
        session_id: str,
        limit: int = 50
    ) -> List[Dict[str, str]]:
        """Return chat history as role/content/timestamp dicts.

        NOTE(review): ``session_id`` is accepted but never used to filter,
        so all of a profile's conversations are returned regardless of
        session -- TODO wire it into the query once the column exists.

        Args:
            profile_id: User profile ID; when None, no profile filter
                is applied.
            session_id: Chat session ID (currently unused, see note).
            limit: Maximum number of conversation rows to load.

        Returns:
            Oldest-first list of messages; empty list on any error.
        """
        try:
            async with get_db_session() as db:
                from sqlalchemy import select
                query = (
                    select(ChatConversation)
                    .order_by(ChatConversation.created_at.asc())
                    .limit(limit)
                )
                if profile_id:
                    query = query.where(ChatConversation.profile_id == profile_id)
                result = await db.execute(query)
                conversations = result.scalars().all()

                # Flatten each row into up to two turns (user + assistant).
                chat_history = []
                for conv in conversations:
                    if conv.message:
                        chat_history.append({
                            "role": "user",
                            "content": conv.message,
                            "timestamp": conv.created_at.isoformat()
                        })
                    if conv.response:
                        chat_history.append({
                            "role": "assistant",
                            "content": conv.response,
                            "timestamp": conv.created_at.isoformat()
                        })
                return chat_history
        except Exception as e:
            logger.error("Error getting chat history: %s", str(e))
            return []
# Module-wide ChatService singleton, created lazily on first use.
chat_service: Optional[ChatService] = None


def get_chat_service() -> ChatService:
    """Return the shared ChatService, constructing it on first call.

    Returns:
        The module-wide ChatService instance.
    """
    global chat_service
    if chat_service is not None:
        return chat_service
    chat_service = ChatService()
    return chat_service
# Convenience function for backward compatibility
async def chat_stream(
    message: str,
    profile_id: Optional[str] = None,
    history: Optional[List[Dict[str, str]]] = None
) -> AsyncGenerator[str, None]:
    """Stream a chat response via the shared ChatService singleton.

    Thin module-level wrapper kept for callers that predate the
    ChatService class.

    Args:
        message: User message.
        profile_id: Optional profile UUID string for personalization.
        history: Prior turns as ``{"role": ..., "content": ...}`` dicts.

    Yields:
        Response content chunks from ChatService.chat_stream.
    """
    service = get_chat_service()
    async for chunk in service.chat_stream(message, profile_id, history):
        yield chunk
if __name__ == "__main__":
    """
    Main block for testing chat service functionality.
    """

    async def main():
        """Drive a few sample elderly-nutrition questions through the service."""
        print("Testing Chat Service...")
        try:
            # Build (or reuse) the shared service instance.
            service = get_chat_service()

            # Representative Traditional-Chinese test questions.
            test_messages = [
                "請問銀髮族應該如何補充蛋白質?",
                "我爸爸有糖尿病,飲食上有什麼需要注意的?",
                "推薦一些適合銀髮族的早餐選項"
            ]

            for i, message in enumerate(test_messages, 1):
                print(f"\n--- 測試對話 {i} ---")
                print(f"使用者:{message}")
                print("助手:", end="", flush=True)
                # Print the streamed answer chunk by chunk, then a divider.
                async for chunk in service.chat_stream(message):
                    print(chunk, end="", flush=True)
                print("\n" + "=" * 50)
        except Exception as e:
            # Surface the failure to the console, then propagate it.
            print(f"Error: {str(e)}")
            raise

    # Run the async test driver.
    asyncio.run(main())