File size: 5,541 Bytes
ef879c5
 
b830e05
 
ef879c5
 
303ff5a
ef879c5
 
 
 
 
 
 
 
 
 
e66ee1b
 
b830e05
e66ee1b
ef879c5
 
 
 
 
 
 
 
 
 
 
 
 
b830e05
ef879c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303ff5a
ef879c5
 
 
 
 
 
 
303ff5a
ef879c5
 
 
303ff5a
 
ef879c5
 
303ff5a
 
 
 
ef879c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e66ee1b
 
 
 
 
ef879c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e66ee1b
 
 
 
 
 
 
 
ef879c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""API endpoints for RAG-powered chat"""

from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy import Float, Integer, String, bindparam, text
from sqlalchemy.ext.asyncio import AsyncSession

from backend.database import get_db
from backend.services.openai_service import OpenAIService

router = APIRouter(prefix="/api/chat", tags=["chat"])


class ChatRequest(BaseModel):
    """Parameters for a RAG chat query.

    Attributes:
        question: The user's natural-language question.
        top_k: Maximum number of chunks to retrieve; default raised to 7 for
            better coverage of the ~59-chunk corpus.
        similarity_threshold: Minimum cosine similarity a chunk must reach
            to be used as context (0.55 balances recall vs. quality).
        system_prompt: Optional override for the default counselor persona.
        temperature: LLM sampling temperature; 0.6 favors accuracy over
            creativity.
    """

    question: str
    top_k: int = 7
    similarity_threshold: float = 0.55
    system_prompt: Optional[str] = None
    temperature: float = 0.6


class Citation(BaseModel):
    """One retrieved chunk that backs part of the generated answer."""

    chunk_id: int
    doc_id: int
    document_title: str
    text: str
    similarity_score: float


class ChatResponse(BaseModel):
    """Answer payload returned to the client, with its supporting citations."""

    question: str
    answer: str
    citations: List[Citation]
    total_citations: int

@router.post("/", response_model=ChatResponse)
async def chat_with_rag(request: ChatRequest, db: AsyncSession = Depends(get_db)) -> ChatResponse:
    """
    Answer questions using RAG (Retrieval-Augmented Generation)

    Process:
    1. Generate embedding for question
    2. Search for similar chunks (pgvector cosine similarity)
    3. Construct context from top results
    4. Generate answer using LLM with context
    5. Return answer with citations

    Args:
        request: Chat question and parameters (top_k, threshold, temperature)
        db: Database session (injected via get_db dependency)

    Returns:
        ChatResponse with answer and citations; when no chunk clears the
        similarity threshold, a canned "no relevant info" answer with zero
        citations is returned instead of an error.

    Raises:
        HTTPException: 500 for any failure (embedding, DB query, or LLM call).
    """

    try:
        # Initialize OpenAI service
        openai_service = OpenAIService()

        # 1. Generate embedding for question
        query_embedding = await openai_service.create_embedding(request.question)

        # Convert embedding to PostgreSQL array format: pgvector accepts a
        # bracketed, comma-separated string that is CAST to `vector` in SQL.
        embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"

        # 2. Search for similar chunks
        # Use bindparam for proper parameter binding with asyncpg.
        # `<=>` is pgvector's cosine-distance operator, so
        # similarity = 1 - distance; ordering by raw distance (ascending)
        # returns the closest chunks first.
        query_sql = text(
            """
            SELECT
                c.id as chunk_id,
                c.doc_id,
                c.text,
                d.title as document_title,
                1 - (e.embedding <=> CAST(:query_embedding AS vector)) as similarity_score
            FROM chunks c
            JOIN embeddings e ON c.id = e.chunk_id
            JOIN documents d ON c.doc_id = d.id
            WHERE 1 - (e.embedding <=> CAST(:query_embedding AS vector)) >= :threshold
            ORDER BY e.embedding <=> CAST(:query_embedding AS vector)
            LIMIT :top_k
        """
        ).bindparams(
            # Explicit types so asyncpg binds the parameters correctly.
            bindparam("query_embedding", type_=String),
            bindparam("threshold", type_=Float),
            bindparam("top_k", type_=Integer),
        )

        result = await db.execute(
            query_sql,
            {
                "query_embedding": embedding_str,
                "threshold": request.similarity_threshold,
                "top_k": request.top_k,
            },
        )

        rows = result.fetchall()

        # No chunk cleared the threshold: return guidance (in Traditional
        # Chinese) instead of raising, so the client gets a usable reply.
        if not rows:
            return ChatResponse(
                question=request.question,
                answer=(
                    "抱歉,我在文檔中找不到足夠的相關資訊來回答這個問題。\n\n"
                    "建議:\n"
                    "1. 請嘗試更具體地描述您的問題\n"
                    "2. 使用職涯相關的關鍵詞(例如:職涯發展、生涯規劃、諮詢技巧等)\n"
                    "3. 確認相關文檔已上傳"
                ),
                citations=[],
                total_citations=0,
            )

        # 3. Construct context from top results; the [n] markers let the LLM
        # cite sources by number, matching the citations list order.
        context_parts = []
        citations = []

        for idx, row in enumerate(rows):
            context_parts.append(f"[{idx + 1}] {row.text}")

            citations.append(
                Citation(
                    chunk_id=row.chunk_id,
                    doc_id=row.doc_id,
                    document_title=row.document_title,
                    text=row.text,
                    similarity_score=float(row.similarity_score),
                )
            )

        context = "\n\n".join(context_parts)

        # 4. Generate answer using LLM with context.
        # Default persona: professional career-counselor assistant; instructs
        # the model to cite sources as [1], [2], stay grounded in the provided
        # documents, and reply in the question's language (zh-TW or English).
        system_prompt = request.system_prompt or (
            "你是一位專業的職涯諮詢師助理。請根據提供的文檔內容回答問題。\n\n"
            "回答時請:\n"
            "1. 以提供的文檔內容為主要依據,並使用引用編號 [1], [2] 等標示資訊來源\n"
            "2. 保持專業、客觀、同理的語氣\n"
            "3. 如果文檔中沒有相關資訊,請明確說明,不要猜測或編造\n"
            "4. 適當引用文檔中的關鍵概念、理論和實務做法\n"
            "5. 考慮個別差異和多元觀點\n"
            "6. 使用與問題相同的語言回答(繁體中文或英文)"
        )

        answer = await openai_service.chat_completion_with_context(
            question=request.question,
            context=context,
            system_prompt=system_prompt,
            temperature=request.temperature,
        )

        # 5. Return answer with citations
        return ChatResponse(
            question=request.question,
            answer=answer,
            citations=citations,
            total_citations=len(citations),
        )

    except Exception as e:
        # NOTE(review): str(e) may expose internal details (SQL, service
        # errors) to the client in the 500 detail — consider logging the
        # exception server-side and returning a generic message instead.
        raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}") from e