| | """ |
| | Chat service for Silver Table Assistant. |
| | Provides AI-powered chat functionality with RAG integration and user context. |
| | """ |
| |
|
| | import os |
| | import logging |
| | from typing import List, Dict, Any, Optional, AsyncGenerator |
| | from uuid import UUID |
| |
|
| | from langchain_openai import ChatOpenAI |
| | from langchain_core.messages import HumanMessage, SystemMessage, AIMessage |
| | from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
| |
|
| | import asyncio |
| | from sqlalchemy.ext.asyncio import AsyncSession |
| |
|
| | from rag import get_rag_service |
| | from crud import get_profile |
| | from models import ChatConversation, Profile |
| | from database import get_db_session |
| | from config import settings |
| |
|
| | |
# NOTE(review): calling basicConfig() at import time configures the root
# logger as a module side effect. That is acceptable for a standalone
# service process, but if this module is imported as a library, consider
# moving the basicConfig() call to the application entry point.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
| |
|
| |
|
class ChatService:
    """AI-powered chat service with RAG integration for the Silver Table Assistant.

    Wraps a LangChain ``ChatOpenAI`` model (pointed at either the official
    OpenAI API or a LiteLLM-compatible gateway), retrieves relevant
    nutrition-guideline documents through the RAG service, injects the
    user's profile context into the system prompt, and streams personalized
    responses. Every completed exchange is logged to the database.
    """

    def __init__(self):
        """Initialize the chat service with a LiteLLM or OpenAI chat model.

        Credentials and endpoint are read from the environment:
        ``OPENAI_API_KEY`` / ``LITELLM_API_KEY`` for the key and
        ``OPENAI_BASE_URL`` / ``LITELLM_BASE_URL`` for an optional
        custom gateway URL.

        Raises:
            ValueError: If no API key is provided via the environment.
        """
        # SECURITY FIX: read credentials strictly from the environment.
        # The previous revision embedded a live-looking API key and an
        # internal gateway URL as hard-coded fallbacks; secrets must never
        # be committed to source. This also makes the missing-key guard
        # below reachable again (it was dead code with a baked-in default).
        self.openai_api_key = os.getenv("OPENAI_API_KEY") or os.getenv("LITELLM_API_KEY")
        self.openai_base_url = os.getenv("OPENAI_BASE_URL") or os.getenv("LITELLM_BASE_URL")

        if not self.openai_api_key:
            raise ValueError("Missing required environment variable: OPENAI_API_KEY or LITELLM_API_KEY")

        model_kwargs = {
            "model": settings.ai_model_name,
            "openai_api_key": self.openai_api_key,
            "temperature": settings.ai_temperature,
            "max_tokens": settings.ai_max_tokens,
        }

        # Only override the endpoint when a custom gateway is configured;
        # otherwise ChatOpenAI falls back to the official OpenAI API.
        if self.openai_base_url:
            model_kwargs["openai_api_base"] = self.openai_base_url

        self.llm = ChatOpenAI(**model_kwargs)
        logger.info(f"Initialized ChatOpenAI with base_url: {self.openai_base_url}, model: {settings.ai_model_name}")

        # Static system prompt shared by every conversation.
        self.system_prompt = self._create_system_prompt()

        # Shared RAG service used to fetch relevant guideline documents.
        self.rag_service = get_rag_service()

    def _create_system_prompt(self) -> str:
        """
        Create the system prompt for the silver table assistant.

        The prompt (in Traditional Chinese) defines the assistant's persona,
        core principles (follow Taiwan MOHW dietary guidelines, never
        diagnose), hard limits, and response style. It is a runtime string
        sent to the model and is therefore kept verbatim.

        Returns:
            System prompt string.
        """
        return """你是「銀髮餐桌助手」,專為台灣銀髮族設計的AI營養飲食顧問助手。

角色定位:
- 你是一位溫暖、耐心、專業的營養飲食顧問
- 專門為台灣銀髮族(65歲以上)提供飲食建議
- 熟悉台灣在地食材、飲食文化和生活習慣

核心原則:
1. 嚴格遵循台灣衛福部(MOHW)的營養指導原則和飲食指南
2. 僅提供營養建議,絕不進行醫療診斷或疾病診斷
3. 針對銀髮族的特殊營養需求(蛋白質、鈣質、維生素D、纖維等)
4. 考慮台灣在地飲食文化和可用食材
5. 語調溫和、耐心,像家中長輩般的關懷

重要限制:
- 絕不提供醫療診斷或疾病治療建議
- 涉及健康問題時,建議諮詢專業醫師
- 不推薦特定品牌或產品
- 基於科學證據和官方營養指南提供建議

回應風格:
- 使用繁體中文
- 語調溫暖親切
- 提供具體可行的建議
- 適時提供鼓勵和關懷
- 考慮使用者的年齡和健康狀況

當使用者詢問營養、飲食、食材選擇、烹調方式等相關問題時,請基於台灣衛福部的營養指導原則回答,並考慮使用者的個人健康狀況(如果有提供的話)。"""

    async def get_user_context(self, profile_id: Optional[UUID] = None) -> Dict[str, Any]:
        """
        Get user context information for personalized responses.

        Looks up the user's profile in the database. Lookup failures are
        logged and swallowed deliberately so chat still works without a
        profile (best-effort personalization).

        Args:
            profile_id: User profile ID, or None for an anonymous user.

        Returns:
            Dictionary with user context information; ``has_profile`` is
            False when no profile was found or lookup failed.
        """
        context = {
            "has_profile": False,
            "age": None,
            "health_conditions": None,
            "dietary_restrictions": None
        }

        if profile_id:
            try:
                async with get_db_session() as db:
                    profile = await get_profile(db, profile_id)
                    if profile:
                        context.update({
                            "has_profile": True,
                            "age": profile.age,
                            "health_conditions": profile.health_condition,
                            "dietary_restrictions": profile.dietary_restrictions,
                            "display_name": profile.display_name
                        })
                        logger.info(f"Retrieved profile context for user {profile_id}")
                    else:
                        logger.warning(f"Profile not found for ID: {profile_id}")
            except Exception as e:
                # Best-effort: fall back to anonymous context on DB errors.
                logger.error(f"Error retrieving user profile: {str(e)}")

        return context

    def format_context_information(self, user_context: Dict[str, Any], relevant_docs: List[Any]) -> str:
        """
        Format context information for the AI prompt.

        Renders the user's profile (if any) and the retrieved guideline
        documents as a Traditional-Chinese text section appended to the
        system prompt.

        Args:
            user_context: User context dictionary as produced by
                :meth:`get_user_context`.
            relevant_docs: Relevant documents from RAG (objects exposing
                ``metadata`` and ``page_content``).

        Returns:
            Formatted context string; empty when there is neither a
            profile nor any relevant document.
        """
        context_parts = []

        if user_context["has_profile"]:
            context_parts.append("使用者背景資訊:")
            if user_context.get("display_name"):
                context_parts.append(f"- 姓名:{user_context['display_name']}")
            if user_context.get("age"):
                context_parts.append(f"- 年齡:{user_context['age']}歲")
            if user_context.get("health_conditions"):
                context_parts.append(f"- 健康狀況:{user_context['health_conditions']}")
            if user_context.get("dietary_restrictions"):
                context_parts.append(f"- 飲食限制:{user_context['dietary_restrictions']}")
            context_parts.append("")

        if relevant_docs:
            context_parts.append("相關營養指南資訊:")
            for i, doc in enumerate(relevant_docs, 1):
                source = doc.metadata.get("file_name", "未知來源")
                content = doc.page_content.strip()
                # Truncate long excerpts to keep the prompt within budget.
                if len(content) > 500:
                    content = content[:500] + "..."
                context_parts.append(f"{i}. 來源:{source}")
                context_parts.append(f"   內容:{content}")
            context_parts.append("")

        return "\n".join(context_parts)

    async def chat_stream(
        self,
        message: str,
        profile_id: Optional[str] = None,
        history: Optional[List[Dict[str, str]]] = None
    ) -> AsyncGenerator[str, None]:
        """
        Stream chat response with context and RAG integration.

        Pipeline: parse profile ID -> load user context -> retrieve relevant
        documents -> assemble system/history/user messages -> stream model
        output. The completed exchange is logged afterwards. Any failure
        yields a fixed Chinese apology message instead of raising.

        Args:
            message: User message.
            profile_id: Optional user profile ID (string UUID) for
                personalization; invalid formats are logged and ignored.
            history: Chat history messages, each a dict with ``role``
                ("user"/"assistant") and ``content``.
                (Annotation fixed: was ``List[...] = None``, now Optional.)

        Yields:
            Response content chunks.
        """
        try:
            # Parse the profile ID leniently; a malformed ID just means
            # the reply is not personalized.
            profile_uuid = None
            if profile_id:
                try:
                    profile_uuid = UUID(profile_id)
                except ValueError:
                    logger.warning(f"Invalid profile ID format: {profile_id}")

            user_context = await self.get_user_context(profile_uuid)

            # Top-6 guideline documents relevant to the user's question.
            relevant_docs = await self.rag_service.get_relevant_documents(message, k=6)

            context_info = self.format_context_information(user_context, relevant_docs)

            messages = []

            # Append the dynamic context to the static system prompt.
            if context_info:
                system_content = f"{self.system_prompt}\n\n背景資訊:\n{context_info}"
            else:
                system_content = self.system_prompt

            messages.append(SystemMessage(content=system_content))

            # Replay prior turns so the model keeps conversational state.
            if history:
                for msg in history:
                    if msg["role"] == "user":
                        messages.append(HumanMessage(content=msg["content"]))
                    elif msg["role"] == "assistant":
                        messages.append(AIMessage(content=msg["content"]))

            messages.append(HumanMessage(content=message))

            logger.info(f"Generating chat response for message: '{message[:50]}...'")

            # Stream chunks to the caller while accumulating the full
            # response for logging.
            full_response = ""
            async for chunk in self.llm.astream(messages):
                if hasattr(chunk, "content") and chunk.content:
                    full_response += chunk.content
                    yield chunk.content

            if full_response:
                await self._log_conversation(
                    message=message,
                    response=full_response,
                    profile_id=profile_uuid,
                    user_context=user_context,
                    relevant_docs_count=len(relevant_docs)
                )

        except Exception as e:
            # User-facing fallback; the real error goes to the log.
            logger.error(f"Error in chat stream: {str(e)}")
            yield "抱歉,系統發生了一些問題。請稍後再試。"

    async def _log_conversation(
        self,
        message: str,
        response: str,
        profile_id: Optional[UUID],
        user_context: Dict[str, Any],
        relevant_docs_count: int
    ) -> None:
        """
        Log conversation to database for analytics and improvement.

        Failures are logged and swallowed so logging never breaks the chat.

        Args:
            message: User message.
            response: AI response.
            profile_id: User profile ID.
            user_context: User context information.
            relevant_docs_count: Number of relevant documents found.
        """
        try:
            metadata = {
                "user_context": user_context,
                "relevant_docs_count": relevant_docs_count,
                "timestamp": settings.get_current_timestamp()
            }

            await self._save_conversation(
                profile_id=profile_id,
                message=message,
                response=response,
                meta_data=metadata
            )

        except Exception as e:
            logger.error(f"Error logging conversation: {str(e)}")

    async def _save_conversation(
        self,
        profile_id: Optional[UUID],
        message: str,
        response: Optional[str] = None,
        meta_data: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Save chat conversation to database.

        Args:
            profile_id: User profile ID.
            message: User message.
            response: AI response (optional).
            meta_data: Additional metadata stored as JSON.
        """
        try:
            async with get_db_session() as db:
                conversation = ChatConversation(
                    profile_id=profile_id,
                    message=message,
                    response=response,
                    meta_data=meta_data or {}
                )

                db.add(conversation)
                await db.commit()

        except Exception as e:
            logger.error(f"Error saving conversation to database: {str(e)}")

    async def get_chat_history(
        self,
        profile_id: Optional[UUID],
        session_id: str,
        limit: int = 50
    ) -> List[Dict[str, str]]:
        """
        Get chat history for a session.

        NOTE(review): ``session_id`` is accepted but never applied to the
        query, so history is effectively per-profile (or global when
        profile_id is None), not per-session. Confirm whether
        ChatConversation has a session column and add the filter; not
        changed here because the model schema is not visible.

        Args:
            profile_id: User profile ID.
            session_id: Chat session ID (currently unused — see NOTE).
            limit: Maximum number of conversation rows to return.

        Returns:
            List of chat messages ({"role", "content", "timestamp"}),
            oldest first; empty list on error.
        """
        try:
            async with get_db_session() as db:
                from sqlalchemy import select

                query = (
                    select(ChatConversation)
                    .order_by(ChatConversation.created_at.asc())
                    .limit(limit)
                )

                if profile_id:
                    query = query.where(ChatConversation.profile_id == profile_id)

                result = await db.execute(query)
                conversations = result.scalars().all()

                # Flatten each row into a user turn and an assistant turn.
                chat_history = []
                for conv in conversations:
                    if conv.message:
                        chat_history.append({
                            "role": "user",
                            "content": conv.message,
                            "timestamp": conv.created_at.isoformat()
                        })
                    if conv.response:
                        chat_history.append({
                            "role": "assistant",
                            "content": conv.response,
                            "timestamp": conv.created_at.isoformat()
                        })

                return chat_history

        except Exception as e:
            logger.error(f"Error getting chat history: {str(e)}")
            return []
| |
|
| |
|
| | |
# Process-wide singleton; populated lazily on first use.
chat_service: Optional[ChatService] = None


def get_chat_service() -> ChatService:
    """
    Return the shared chat service, constructing it on first call.

    Returns:
        The lazily-created module-level ChatService instance.
    """
    global chat_service
    if chat_service is None:
        chat_service = ChatService()
    return chat_service
| |
|
| |
|
| | |
async def chat_stream(
    message: str,
    profile_id: Optional[str] = None,
    history: Optional[List[Dict[str, str]]] = None
) -> AsyncGenerator[str, None]:
    """
    Stream a chat response via the shared ChatService.

    Module-level convenience wrapper around ChatService.chat_stream.
    (Annotation fixed: ``history`` defaulted to None but was typed as a
    bare List; it is now Optional — backward-compatible for all callers.)

    Args:
        message: User message.
        profile_id: Optional user profile ID (string UUID).
        history: Prior chat turns ({"role", "content"} dicts), or None.

    Yields:
        Response content chunks.
    """
    service = get_chat_service()
    async for chunk in service.chat_stream(message, profile_id, history):
        yield chunk
| |
|
| |
|
if __name__ == "__main__":
    # Manual smoke test: stream a few sample questions through the live
    # service and print the responses.
    async def main():
        """Exercise the chat service with representative elderly-nutrition questions."""
        print("Testing Chat Service...")

        try:
            service = get_chat_service()

            test_messages = [
                "請問銀髮族應該如何補充蛋白質?",
                "我爸爸有糖尿病,飲食上有什麼需要注意的?",
                "推薦一些適合銀髮族的早餐選項"
            ]

            for i, message in enumerate(test_messages, 1):
                print(f"\n--- 測試對話 {i} ---")
                print(f"使用者:{message}")
                print("助手:", end="", flush=True)

                # Print the streamed answer chunk by chunk.
                async for chunk in service.chat_stream(message):
                    print(chunk, end="", flush=True)

                print("\n" + "="*50)

        except Exception as e:
            print(f"Error: {str(e)}")
            raise

    asyncio.run(main())