File size: 4,095 Bytes
c4233b7
6e1d29c
c4233b7
 
 
 
 
 
6e1d29c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4233b7
6e1d29c
 
c4233b7
6e1d29c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4233b7
 
 
 
6e1d29c
c4233b7
 
 
 
 
 
 
6e1d29c
c4233b7
6e1d29c
 
 
c4233b7
 
 
 
6e1d29c
c4233b7
 
 
 
 
 
 
 
6e1d29c
 
 
 
 
c4233b7
6e1d29c
 
c4233b7
 
 
 
 
6e1d29c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137

from typing import List, Tuple, Dict, Any, Optional
from src.embeddings import embed_texts
from src.vectorstore import query_by_embedding
from src.openai_client import get_client
from src.config import CHAT_MODEL, TOP_K


# ---------------- Query Rewrite (Domain-agnostic) ----------------
def rewrite_queries(question: str, n: int = 4) -> List[str]:
    """
    Generate alternative phrasings of *question* to improve retrieval recall.

    The chat model is asked for short search-query variants (synonyms,
    abbreviations, likely in-document wording) with no domain-specific
    hardcoding. The original question is always kept as the first entry,
    duplicates are dropped case-insensitively, and at most n + 1 queries
    are returned.
    """
    client = get_client()
    prompt = f"""
You help a RAG system retrieve relevant document chunks.

Rewrite the user query into {n} short alternative search queries that capture the same intent.
Include abbreviations, synonyms, and likely wording that might appear in documents.
Return ONLY the queries, one per line. No numbering, no extra text.

User query: {question}
"""
    response = client.responses.create(model=CHAT_MODEL, input=prompt)
    variants = [line.strip() for line in response.output_text.splitlines() if line.strip()]

    # Original question first, then model variants; drop case-insensitive
    # duplicates while preserving order.
    unique: List[str] = []
    lowered = set()
    for candidate in [question, *variants]:
        key = candidate.lower()
        if key in lowered:
            continue
        lowered.add(key)
        unique.append(candidate)

    return unique[: n + 1]


# ---------------- Clarification (Domain-agnostic) ----------------
def clarification_question(user_query: str) -> Optional[str]:
    """
    Ask the chat model whether *user_query* is too short or ambiguous
    for document retrieval.

    Returns:
        A single short clarification question if one is needed,
        otherwise None.
    """
    client = get_client()
    prompt = f"""
Decide if this user query is too short or ambiguous for document retrieval.
If clarification is needed, return ONE short clarification question.
If not needed, return exactly: NO

User query: {user_query}
"""
    resp = client.responses.create(model=CHAT_MODEL, input=prompt)
    out = resp.output_text.strip()
    # Models often deviate slightly from the requested literal "NO"
    # (e.g. "No.", '"NO"', "no!"); strip surrounding quotes/punctuation
    # before comparing so such replies are not mistakenly returned to
    # the user as a "clarification question".
    if out.strip("\"'.! ").upper() == "NO":
        return None
    return out


# ---------------- Multi-query Retrieval + Dedupe ----------------
def retrieve_context(question: str, top_k: int = TOP_K) -> Tuple[str, List[str]]:
    """
    Retrieve document context for *question* via multi-query expansion.

    Each rewritten query is embedded and searched independently; the pooled
    hits are de-duplicated by (file, page, leading 160-char snippet),
    truncated to *top_k*, and formatted as numbered source blocks.

    Returns:
        (context_string, citations_list)
    """
    pooled: List[Tuple[str, Dict[str, Any]]] = []
    for query in rewrite_queries(question, n=4):
        vec = embed_texts([query])[0]
        docs, metas = query_by_embedding(vec, top_k=top_k)
        pooled.extend(zip(docs, metas))

    # Drop near-duplicate hits: same file/page with an identical snippet.
    seen = set()
    kept: List[Tuple[str, Dict[str, Any]]] = []
    for doc, meta in pooled:
        key = (meta.get("file"), meta.get("page"), doc[:160] if doc else "")
        if key in seen:
            continue
        seen.add(key)
        kept.append((doc, meta))

    kept = kept[:top_k]

    blocks: List[str] = []
    citations: List[str] = []
    for idx, (doc, meta) in enumerate(kept, start=1):
        label = f"{meta.get('file')} (page {meta.get('page')})"
        citations.append(f"[{idx}] {label}")
        blocks.append(f"Source {idx}: {label}\n{doc}")

    return "\n\n---\n\n".join(blocks), citations


def answer_question(question: str) -> Tuple[str, List[str]]:
    """
    Produce an answer grounded in retrieved document sources.

    Retrieves context via multi-query search, then prompts the chat model
    to answer strictly from those sources.

    Returns:
        (answer_text, citations_list)
    """
    context, citations = retrieve_context(question, top_k=TOP_K)

    prompt = f"""
You are a document assistant.
Answer using the SOURCES below.
If the answer is not in the sources, say: "I don't know from the uploaded documents."

SOURCES:
{context}

QUESTION:
{question}

Rules:
- Be helpful and concise.
- It's okay to paraphrase, but do not invent facts.
- At the end, list: Sources used: [numbers only]

Return:
1) Answer
2) Sources used: [..]
"""

    response = get_client().responses.create(model=CHAT_MODEL, input=prompt)
    return response.output_text.strip(), citations