from typing import List, Tuple, Dict, Any, Optional

from src.embeddings import embed_texts
from src.vectorstore import query_by_embedding
from src.openai_client import get_client
from src.config import CHAT_MODEL, TOP_K


# ---------------- Query Rewrite (Domain-agnostic) ----------------
def rewrite_queries(question: str, n: int = 4) -> List[str]:
    """
    Creates multiple semantic variants of the user query to improve recall.
    Works for any domain (medical/legal/finance/etc.) without hardcoded synonyms.

    Returns the original query first, followed by up to `n` model-generated
    rewrites, deduplicated case-insensitively while preserving order
    (at most n + 1 queries total).
    """
    client = get_client()
    prompt = f"""
You help a RAG system retrieve relevant document chunks.
Rewrite the user query into {n} short alternative search queries that capture the same intent.
Include abbreviations, synonyms, and likely wording that might appear in documents.
Return ONLY the queries, one per line. No numbering, no extra text.

User query: {question}
"""
    resp = client.responses.create(model=CHAT_MODEL, input=prompt)
    lines = [ln.strip() for ln in resp.output_text.splitlines() if ln.strip()]

    # Always include original first + dedupe (case-insensitive, order-preserving)
    # so the user's own wording is never dropped in favor of a rewrite.
    seen = set()
    final: List[str] = []
    for q in [question] + lines:
        key = q.lower()
        if key not in seen:
            seen.add(key)
            final.append(q)
    return final[: n + 1]


# ---------------- Clarification (Domain-agnostic) ----------------
def clarification_question(user_query: str) -> Optional[str]:
    """
    If the query is too short/ambiguous, returns a clarification question.
    Otherwise returns None.
    """
    client = get_client()
    prompt = f"""
Decide if this user query is too short or ambiguous for document retrieval.
If clarification is needed, return ONE short clarification question.
If not needed, return exactly: NO

User query: {user_query}
"""
    resp = client.responses.create(model=CHAT_MODEL, input=prompt)
    out = resp.output_text.strip()
    # Models sometimes wrap the sentinel in quotes or punctuation
    # (e.g. '"NO."'); strip those before comparing so we don't surface
    # a bogus "clarification" to the caller.
    if out.strip("'\".").upper() == "NO":
        return None
    return out


# ---------------- Multi-query Retrieval + Dedupe ----------------
def retrieve_context(question: str, top_k: int = TOP_K) -> Tuple[str, List[str]]:
    """
    Retrieves context using multi-query rewrite to improve semantic matches.
    Returns (context_string, citations_list).

    Each rewritten query is searched against the vector store; results are
    merged, deduplicated, and truncated to `top_k` chunks.
    """
    queries = rewrite_queries(question, n=4)

    # Embed all query variants in ONE batched call instead of one embedding
    # request per variant — same vectors, fewer API round trips.
    vectors = embed_texts(queries)

    all_docs: List[str] = []
    all_metas: List[Dict[str, Any]] = []
    for vec in vectors:
        docs, metas = query_by_embedding(vec, top_k=top_k)
        all_docs.extend(docs)
        all_metas.extend(metas)

    # Deduplicate by (file, page, snippet) — the same chunk is frequently
    # retrieved by several of the query variants.
    seen = set()
    final_docs: List[str] = []
    final_metas: List[Dict[str, Any]] = []
    for doc, meta in zip(all_docs, all_metas):
        fingerprint = (meta.get("file"), meta.get("page"), doc[:160] if doc else "")
        if fingerprint not in seen:
            seen.add(fingerprint)
            final_docs.append(doc)
            final_metas.append(meta)

    final_docs = final_docs[:top_k]
    final_metas = final_metas[:top_k]

    context_blocks: List[str] = []
    citations: List[str] = []
    for i, (doc, meta) in enumerate(zip(final_docs, final_metas), start=1):
        citations.append(f"[{i}] {meta.get('file')} (page {meta.get('page')})")
        context_blocks.append(
            f"Source {i}: {meta.get('file')} (page {meta.get('page')})\n{doc}"
        )
    return "\n\n---\n\n".join(context_blocks), citations


def answer_question(question: str) -> Tuple[str, List[str]]:
    """
    Answers grounded in retrieved sources.

    Returns (answer_text, citations_list); the model is instructed to refuse
    when the retrieved sources don't contain the answer.
    """
    context, citations = retrieve_context(question, top_k=TOP_K)
    prompt = f"""
You are a document assistant. Answer using the SOURCES below.
If the answer is not in the sources, say: "I don't know from the uploaded documents."

SOURCES:
{context}

QUESTION: {question}

Rules:
- Be helpful and concise.
- It's okay to paraphrase, but do not invent facts.
- At the end, list: Sources used: [numbers only]

Return:
1) Answer
2) Sources used: [..]
"""
    client = get_client()
    resp = client.responses.create(model=CHAT_MODEL, input=prompt)
    return resp.output_text.strip(), citations