Spaces:
Running
Running
File size: 4,095 Bytes
c4233b7 6e1d29c c4233b7 6e1d29c c4233b7 6e1d29c c4233b7 6e1d29c c4233b7 6e1d29c c4233b7 6e1d29c c4233b7 6e1d29c c4233b7 6e1d29c c4233b7 6e1d29c c4233b7 6e1d29c c4233b7 6e1d29c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from typing import List, Tuple, Dict, Any, Optional
from src.embeddings import embed_texts
from src.vectorstore import query_by_embedding
from src.openai_client import get_client
from src.config import CHAT_MODEL, TOP_K
# ---------------- Query Rewrite (Domain-agnostic) ----------------
def rewrite_queries(question: str, n: int = 4) -> List[str]:
    """
    Produce semantic variants of the user query to improve retrieval recall.

    Domain-agnostic (medical/legal/finance/etc.): relies on the chat model
    rather than hardcoded synonym tables.

    Returns:
        The original query first, followed by up to ``n`` model-generated
        rewrites, deduplicated case-insensitively in order of appearance.
    """
    client = get_client()
    prompt = f"""
You help a RAG system retrieve relevant document chunks.
Rewrite the user query into {n} short alternative search queries that capture the same intent.
Include abbreviations, synonyms, and likely wording that might appear in documents.
Return ONLY the queries, one per line. No numbering, no extra text.
User query: {question}
"""
    resp = client.responses.create(model=CHAT_MODEL, input=prompt)
    candidates = [line.strip() for line in resp.output_text.splitlines() if line.strip()]
    # Original query always leads; drop case-insensitive duplicates while
    # preserving first-seen order.
    unique: List[str] = []
    seen_keys = set()
    for candidate in [question, *candidates]:
        key = candidate.lower()
        if key in seen_keys:
            continue
        seen_keys.add(key)
        unique.append(candidate)
    return unique[: n + 1]
# ---------------- Clarification (Domain-agnostic) ----------------
def clarification_question(user_query: str) -> Optional[str]:
    """
    Ask the model whether *user_query* is too short/ambiguous for retrieval.

    Args:
        user_query: The raw query typed by the user.

    Returns:
        A single short clarification question when the model deems one
        necessary, otherwise ``None``.
    """
    client = get_client()
    prompt = f"""
Decide if this user query is too short or ambiguous for document retrieval.
If clarification is needed, return ONE short clarification question.
If not needed, return exactly: NO
User query: {user_query}
"""
    resp = client.responses.create(model=CHAT_MODEL, input=prompt)
    out = resp.output_text.strip()
    # LLMs often decorate the sentinel despite "return exactly: NO"
    # (e.g. "No.", '"NO"', "NO!"). Normalize punctuation/quotes before
    # comparing so those are still treated as "no clarification needed".
    # An empty reply likewise means there is no question to ask.
    if not out or out.strip("'\".! ").upper() == "NO":
        return None
    return out
# ---------------- Multi-query Retrieval + Dedupe ----------------
def retrieve_context(question: str, top_k: int = TOP_K) -> Tuple[str, List[str]]:
    """
    Retrieve supporting context using multi-query rewrites for better recall.

    Each query variant is embedded and searched independently; the merged
    hits are deduplicated by (file, page, snippet prefix), truncated to
    ``top_k``, and rendered into numbered source blocks.

    Returns:
        (context_string, citations_list)
    """
    # Gather (doc, meta) pairs across every rewritten query variant.
    collected: List[Tuple[str, Dict[str, Any]]] = []
    for variant in rewrite_queries(question, n=4):
        vec = embed_texts([variant])[0]
        docs, metas = query_by_embedding(vec, top_k=top_k)
        collected.extend(zip(docs, metas))

    # The same chunk often matches several variants — keep first occurrence.
    seen_fingerprints = set()
    deduped: List[Tuple[str, Dict[str, Any]]] = []
    for doc, meta in collected:
        fingerprint = (meta.get("file"), meta.get("page"), doc[:160] if doc else "")
        if fingerprint in seen_fingerprints:
            continue
        seen_fingerprints.add(fingerprint)
        deduped.append((doc, meta))

    deduped = deduped[:top_k]

    context_blocks: List[str] = []
    citations: List[str] = []
    for idx, (doc, meta) in enumerate(deduped, start=1):
        citations.append(f"[{idx}] {meta.get('file')} (page {meta.get('page')})")
        context_blocks.append(
            f"Source {idx}: {meta.get('file')} (page {meta.get('page')})\n{doc}"
        )
    return "\n\n---\n\n".join(context_blocks), citations
def answer_question(question: str) -> Tuple[str, List[str]]:
    """
    Answer *question* grounded in retrieved document sources.

    Retrieves context via :func:`retrieve_context`, then prompts the chat
    model to answer strictly from those sources.

    Returns:
        (answer_text, citations_list)
    """
    client = get_client()
    context, citations = retrieve_context(question, top_k=TOP_K)
    grounded_prompt = f"""
You are a document assistant.
Answer using the SOURCES below.
If the answer is not in the sources, say: "I don't know from the uploaded documents."
SOURCES:
{context}
QUESTION:
{question}
Rules:
- Be helpful and concise.
- It's okay to paraphrase, but do not invent facts.
- At the end, list: Sources used: [numbers only]
Return:
1) Answer
2) Sources used: [..]
"""
    resp = client.responses.create(model=CHAT_MODEL, input=grounded_prompt)
    return resp.output_text.strip(), citations
|