Spaces:
Running
Running
import re
from typing import Any, Dict, List, Optional, Tuple

from src.config import CHAT_MODEL, TOP_K
from src.embeddings import embed_texts
from src.openai_client import get_client
from src.vectorstore import query_by_embedding
| # ---------------- Query Rewrite (Domain-agnostic) ---------------- | |
def rewrite_queries(question: str, n: int = 4) -> List[str]:
    """
    Create up to *n* semantic variants of the user query to improve recall.

    Domain-agnostic (medical/legal/finance/etc.): relies on the LLM rather
    than hardcoded synonym lists.

    Args:
        question: The original user query.
        n: Number of alternative search queries to request from the model.

    Returns:
        A case-insensitively deduplicated list with the original question
        first, capped at n + 1 entries.
    """
    client = get_client()
    prompt = f"""
You help a RAG system retrieve relevant document chunks.
Rewrite the user query into {n} short alternative search queries that capture the same intent.
Include abbreviations, synonyms, and likely wording that might appear in documents.
Return ONLY the queries, one per line. No numbering, no extra text.
User query: {question}
"""
    resp = client.responses.create(model=CHAT_MODEL, input=prompt)
    lines: List[str] = []
    for ln in resp.output_text.splitlines():
        # Models often prepend "1." / "1)" / "-" / "*" despite the prompt's
        # instructions; strip such markers so variants embed cleanly and
        # dedupe correctly.
        cleaned = re.sub(r"^\s*(?:[-*\u2022]|\d+[.)])\s*", "", ln).strip()
        if cleaned:
            lines.append(cleaned)
    # Always include the original first, then dedupe case-insensitively
    # while preserving order.
    seen = set()
    final: List[str] = []
    for q in [question] + lines:
        key = q.lower()
        if key not in seen:
            seen.add(key)
            final.append(q)
    return final[: n + 1]
| # ---------------- Clarification (Domain-agnostic) ---------------- | |
def clarification_question(user_query: str) -> Optional[str]:
    """
    Ask the model whether the query is too short/ambiguous for retrieval.

    Args:
        user_query: The raw user query.

    Returns:
        A short clarification question if one is needed, otherwise None.
    """
    client = get_client()
    prompt = f"""
Decide if this user query is too short or ambiguous for document retrieval.
If clarification is needed, return ONE short clarification question.
If not needed, return exactly: NO
User query: {user_query}
"""
    resp = client.responses.create(model=CHAT_MODEL, input=prompt)
    out = resp.output_text.strip()
    # Models sometimes wrap the sentinel in quotes or add a period
    # ('"NO"', 'NO.'); normalize before comparing so a near-miss "NO"
    # doesn't trigger a spurious clarification round-trip.
    verdict = out.strip("\"'").rstrip(".").strip().upper()
    if verdict == "NO":
        return None
    return out
| # ---------------- Multi-query Retrieval + Dedupe ---------------- | |
def retrieve_context(question: str, top_k: int = TOP_K) -> Tuple[str, List[str]]:
    """
    Retrieve supporting chunks via multi-query rewrite for better semantic
    matches, then deduplicate and format them.

    Args:
        question: The user's question.
        top_k: Maximum number of chunks to keep (also per-query fetch size).

    Returns:
        A tuple of (context_string, citations_list), where the context joins
        labelled source blocks and each citation is "[i] file (page p)".
    """
    candidate_docs: List[str] = []
    candidate_metas: List[Dict[str, Any]] = []
    for variant in rewrite_queries(question, n=4):
        vec = embed_texts([variant])[0]
        docs, metas = query_by_embedding(vec, top_k=top_k)
        candidate_docs += docs
        candidate_metas += metas

    # Drop duplicates across query variants, keyed by (file, page, snippet);
    # stop once top_k unique chunks have been collected.
    seen_keys = set()
    kept: List[Tuple[str, Dict[str, Any]]] = []
    for doc, meta in zip(candidate_docs, candidate_metas):
        key = (meta.get("file"), meta.get("page"), doc[:160] if doc else "")
        if key in seen_keys:
            continue
        seen_keys.add(key)
        kept.append((doc, meta))
        if len(kept) == top_k:
            break

    blocks: List[str] = []
    citations: List[str] = []
    for idx, (doc, meta) in enumerate(kept, start=1):
        label = f"{meta.get('file')} (page {meta.get('page')})"
        citations.append(f"[{idx}] {label}")
        blocks.append(f"Source {idx}: {label}\n{doc}")
    return "\n\n---\n\n".join(blocks), citations
def answer_question(question: str) -> Tuple[str, List[str]]:
    """
    Answer the question grounded in retrieved document sources.

    Args:
        question: The user's question.

    Returns:
        A tuple of (answer_text, citations_list).
    """
    client = get_client()
    context, citations = retrieve_context(question, top_k=TOP_K)
    # The prompt constrains the model to the retrieved SOURCES and asks for
    # an explicit "Sources used" trailer so the UI can map answer -> citations.
    grounded_prompt = f"""
You are a document assistant.
Answer using the SOURCES below.
If the answer is not in the sources, say: "I don't know from the uploaded documents."
SOURCES:
{context}
QUESTION:
{question}
Rules:
- Be helpful and concise.
- It's okay to paraphrase, but do not invent facts.
- At the end, list: Sources used: [numbers only]
Return:
1) Answer
2) Sources used: [..]
"""
    response = client.responses.create(model=CHAT_MODEL, input=grounded_prompt)
    return response.output_text.strip(), citations