"""API endpoints for RAG-powered chat""" from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel from sqlalchemy import Float, Integer, String, bindparam, text from sqlalchemy.ext.asyncio import AsyncSession from backend.database import get_db from backend.services.openai_service import OpenAIService router = APIRouter(prefix="/api/chat", tags=["chat"]) class ChatRequest(BaseModel): question: str top_k: int = 7 # Increased for better coverage with 59 chunks similarity_threshold: float = 0.55 # Balanced threshold for quality results system_prompt: Optional[str] = None temperature: float = 0.6 # Lowered for more accurate responses class Citation(BaseModel): chunk_id: int doc_id: int document_title: str text: str similarity_score: float class ChatResponse(BaseModel): question: str answer: str citations: List[Citation] total_citations: int @router.post("/", response_model=ChatResponse) async def chat_with_rag(request: ChatRequest, db: AsyncSession = Depends(get_db)): """ Answer questions using RAG (Retrieval-Augmented Generation) Process: 1. Generate embedding for question 2. Search for similar chunks 3. Construct context from top results 4. Generate answer using LLM with context 5. Return answer with citations Args: request: Chat question and parameters db: Database session Returns: ChatResponse with answer and citations """ try: # Initialize OpenAI service openai_service = OpenAIService() # 1. Generate embedding for question query_embedding = await openai_service.create_embedding(request.question) # Convert embedding to PostgreSQL array format embedding_str = "[" + ",".join(map(str, query_embedding)) + "]" # 2. Search for similar chunks # Use bindparam for proper parameter binding with asyncpg query_sql = text( """ SELECT c.id as chunk_id, c.doc_id, c.text, d.title as document_title, 1 - (e.embedding <=> CAST(:query_embedding AS vector)) as similarity_score FROM chunks c JOIN embeddings e ON c.id = e.chunk_id JOIN documents d ON c.doc_id = d.id WHERE 1 - (e.embedding <=> CAST(:query_embedding AS vector)) >= :threshold ORDER BY e.embedding <=> CAST(:query_embedding AS vector) LIMIT :top_k """ ).bindparams( bindparam("query_embedding", type_=String), bindparam("threshold", type_=Float), bindparam("top_k", type_=Integer), ) result = await db.execute( query_sql, { "query_embedding": embedding_str, "threshold": request.similarity_threshold, "top_k": request.top_k, }, ) rows = result.fetchall() if not rows: return ChatResponse( question=request.question, answer=( "抱歉,我在文檔中找不到足夠的相關資訊來回答這個問題。\n\n" "建議:\n" "1. 請嘗試更具體地描述您的問題\n" "2. 使用職涯相關的關鍵詞(例如:職涯發展、生涯規劃、諮詢技巧等)\n" "3. 確認相關文檔已上傳" ), citations=[], total_citations=0, ) # 3. Construct context from top results context_parts = [] citations = [] for idx, row in enumerate(rows): context_parts.append(f"[{idx + 1}] {row.text}") citations.append( Citation( chunk_id=row.chunk_id, doc_id=row.doc_id, document_title=row.document_title, text=row.text, similarity_score=float(row.similarity_score), ) ) context = "\n\n".join(context_parts) # 4. Generate answer using LLM with context system_prompt = request.system_prompt or ( "你是一位專業的職涯諮詢師助理。請根據提供的文檔內容回答問題。\n\n" "回答時請:\n" "1. 以提供的文檔內容為主要依據,並使用引用編號 [1], [2] 等標示資訊來源\n" "2. 保持專業、客觀、同理的語氣\n" "3. 如果文檔中沒有相關資訊,請明確說明,不要猜測或編造\n" "4. 適當引用文檔中的關鍵概念、理論和實務做法\n" "5. 考慮個別差異和多元觀點\n" "6. 
使用與問題相同的語言回答(繁體中文或英文)" ) answer = await openai_service.chat_completion_with_context( question=request.question, context=context, system_prompt=system_prompt, temperature=request.temperature, ) # 5. Return answer with citations return ChatResponse( question=request.question, answer=answer, citations=citations, total_citations=len(citations), ) except Exception as e: raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}") from e
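
# The similarity search above relies on the pgvector extension: `<=>` is
# pgvector's cosine-distance operator, so `1 - distance` yields cosine
# similarity. A minimal sketch of the schema the query assumes (illustrative
# only -- the real DDL lives in the project's migrations, and the VECTOR
# dimension must match the embedding model):
#
#   CREATE EXTENSION IF NOT EXISTS vector;
#   CREATE TABLE embeddings (
#       chunk_id  INTEGER PRIMARY KEY REFERENCES chunks (id),
#       embedding VECTOR(1536)
#   );
#   -- Optional ANN index; an exact scan is fine at the ~59-chunk scale noted above
#   CREATE INDEX ON embeddings USING hnsw (embedding vector_cosine_ops);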
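
# A minimal sketch of the OpenAIService interface this module depends on
# (assumption only -- the real implementation lives in
# backend/services/openai_service.py; model names are placeholders):
#
#   from openai import AsyncOpenAI
#
#   class OpenAIService:
#       def __init__(self) -> None:
#           self.client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
#
#       async def create_embedding(self, text: str) -> list[float]:
#           resp = await self.client.embeddings.create(
#               model="text-embedding-3-small", input=text
#           )
#           return resp.data[0].embedding
#
#       async def chat_completion_with_context(
#           self, question: str, context: str, system_prompt: str, temperature: float
#       ) -> str:
#           resp = await self.client.chat.completions.create(
#               model="gpt-4o-mini",
#               temperature=temperature,
#               messages=[
#                   {"role": "system", "content": system_prompt},
#                   {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {question}"},
#               ],
#           )
#           return resp.choices[0].message.content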
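
# Example request (illustrative; assumes the app is served at
# http://localhost:8000 with this router mounted as-is, and that httpx is
# installed):
#
#   import httpx
#
#   resp = httpx.post(
#       "http://localhost:8000/api/chat/",
#       json={"question": "What are common career counselling techniques?", "top_k": 5},
#   )
#   print(resp.json()["answer"])
#
# Omitted fields fall back to the ChatRequest defaults above
# (top_k=7, similarity_threshold=0.55, temperature=0.6, no custom system_prompt).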