# NOTE(review): the lines previously here ("Spaces:" / "Build error" x2) were
# web-UI scrape artifacts, not Python source — replaced so the module parses.
| """API endpoints for RAG-powered chat""" | |
| from typing import List, Optional | |
| from fastapi import APIRouter, Depends, HTTPException | |
| from pydantic import BaseModel | |
| from sqlalchemy import Float, Integer, String, bindparam, text | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from backend.database import get_db | |
| from backend.services.openai_service import OpenAIService | |
| router = APIRouter(prefix="/api/chat", tags=["chat"]) | |
class ChatRequest(BaseModel):
    """Request payload for the RAG chat endpoint."""

    # The user's natural-language question; embedded for similarity search.
    question: str
    top_k: int = 7  # Increased for better coverage with 59 chunks
    similarity_threshold: float = 0.55  # Balanced threshold for quality results
    # Overrides the default career-counseling system prompt when provided.
    system_prompt: Optional[str] = None
    temperature: float = 0.6  # Lowered for more accurate responses
class Citation(BaseModel):
    """One retrieved chunk used as context for (and cited in) the answer."""

    chunk_id: int
    doc_id: int
    document_title: str
    # Full chunk text that was supplied to the LLM as context.
    text: str
    # 1 - cosine distance between the chunk and query embeddings.
    similarity_score: float
class ChatResponse(BaseModel):
    """Response payload: the generated answer plus its supporting sources."""

    # Echo of the original question.
    question: str
    answer: str
    citations: List[Citation]
    # Convenience count; always equals len(citations).
    total_citations: int
# NOTE(review): the route decorator was evidently lost when this file was
# scraped — `router` is otherwise unused. Restored here; confirm the exact
# path ("" vs "/") against the frontend before merging.
@router.post("", response_model=ChatResponse)
async def chat_with_rag(request: ChatRequest, db: AsyncSession = Depends(get_db)):
    """
    Answer questions using RAG (Retrieval-Augmented Generation).

    Process:
        1. Generate embedding for question
        2. Search for similar chunks (pgvector cosine distance)
        3. Construct context from top results
        4. Generate answer using LLM with context
        5. Return answer with citations

    Args:
        request: Chat question and retrieval/generation parameters.
        db: Async database session (injected via get_db).

    Returns:
        ChatResponse with the answer and the citations that back it.

    Raises:
        HTTPException: 500 when embedding, retrieval, or generation fails.
    """
    try:
        # Initialize OpenAI service
        openai_service = OpenAIService()

        # 1. Embed the question, then serialize the vector into pgvector's
        # textual form ("[v1,v2,...]") so it can be bound as a plain string
        # and CAST to vector inside the SQL.
        query_embedding = await openai_service.create_embedding(request.question)
        embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"

        # 2. Search for similar chunks. `<=>` is pgvector's cosine-distance
        # operator, so similarity = 1 - distance; ORDER BY distance ascending
        # yields the best matches first. Explicit bindparam types are needed
        # for correct parameter binding with asyncpg.
        query_sql = text(
            """
            SELECT
                c.id as chunk_id,
                c.doc_id,
                c.text,
                d.title as document_title,
                1 - (e.embedding <=> CAST(:query_embedding AS vector)) as similarity_score
            FROM chunks c
            JOIN embeddings e ON c.id = e.chunk_id
            JOIN documents d ON c.doc_id = d.id
            WHERE 1 - (e.embedding <=> CAST(:query_embedding AS vector)) >= :threshold
            ORDER BY e.embedding <=> CAST(:query_embedding AS vector)
            LIMIT :top_k
            """
        ).bindparams(
            bindparam("query_embedding", type_=String),
            bindparam("threshold", type_=Float),
            bindparam("top_k", type_=Integer),
        )
        result = await db.execute(
            query_sql,
            {
                "query_embedding": embedding_str,
                "threshold": request.similarity_threshold,
                "top_k": request.top_k,
            },
        )
        rows = result.fetchall()

        # Nothing cleared the similarity threshold: return guidance to the
        # user rather than calling the LLM with an empty context.
        if not rows:
            return ChatResponse(
                question=request.question,
                answer=(
                    "抱歉,我在文檔中找不到足夠的相關資訊來回答這個問題。\n\n"
                    "建議:\n"
                    "1. 請嘗試更具體地描述您的問題\n"
                    "2. 使用職涯相關的關鍵詞(例如:職涯發展、生涯規劃、諮詢技巧等)\n"
                    "3. 確認相關文檔已上傳"
                ),
                citations=[],
                total_citations=0,
            )

        # 3. Build the numbered context block and the matching citation list;
        # the [n] markers let the LLM reference its sources by index.
        context_parts = []
        citations = []
        for idx, row in enumerate(rows):
            context_parts.append(f"[{idx + 1}] {row.text}")
            citations.append(
                Citation(
                    chunk_id=row.chunk_id,
                    doc_id=row.doc_id,
                    document_title=row.document_title,
                    text=row.text,
                    similarity_score=float(row.similarity_score),
                )
            )
        context = "\n\n".join(context_parts)

        # 4. Generate the answer with the caller-supplied system prompt, or
        # fall back to the default career-counseling prompt.
        system_prompt = request.system_prompt or (
            "你是一位專業的職涯諮詢師助理。請根據提供的文檔內容回答問題。\n\n"
            "回答時請:\n"
            "1. 以提供的文檔內容為主要依據,並使用引用編號 [1], [2] 等標示資訊來源\n"
            "2. 保持專業、客觀、同理的語氣\n"
            "3. 如果文檔中沒有相關資訊,請明確說明,不要猜測或編造\n"
            "4. 適當引用文檔中的關鍵概念、理論和實務做法\n"
            "5. 考慮個別差異和多元觀點\n"
            "6. 使用與問題相同的語言回答(繁體中文或英文)"
        )
        answer = await openai_service.chat_completion_with_context(
            question=request.question,
            context=context,
            system_prompt=system_prompt,
            temperature=request.temperature,
        )

        # 5. Return answer with citations
        return ChatResponse(
            question=request.question,
            answer=answer,
            citations=citations,
            total_citations=len(citations),
        )
    except HTTPException:
        # Let deliberate HTTP errors propagate unchanged instead of being
        # re-wrapped as opaque 500s by the handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}") from e