Spaces:
Runtime error
Runtime error
| import os | |
| import uuid | |
| import json | |
| from typing import List, Tuple, Dict, Any, Optional | |
| import chromadb | |
| from chromadb.config import Settings | |
| from openai import OpenAI | |
| import gradio as gr | |
| from pypdf import PdfReader | |
| # Cross-encoder (Hugging Face / sentence-transformers) | |
| # pip install sentence-transformers torch | |
| from sentence_transformers import CrossEncoder | |
| # ========================= | |
| # Chroma Client (Persistent) | |
| # ========================= | |
# Chroma client created with the relative path "chroma_db"; anonymized
# telemetry is switched off through the Settings object.
chroma_client = chromadb.PersistentClient(
    path="chroma_db",
    settings=Settings(anonymized_telemetry=False),
)
# Open the "rag_docs" collection, creating it if it does not exist yet.
# The metadata requests the "cosine" space for the HNSW index (presumably so
# that query distances are cosine distances — confirm against Chroma docs).
collection = chroma_client.get_or_create_collection(
    name="rag_docs",
    metadata={"hnsw:space": "cosine"},
)
| # ========================= | |
| # Cross-Encoder (lazy global) | |
| # ========================= | |
CROSS_ENCODER_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# Module-level cache for the model instance; stays None until first requested.
_CROSS_ENCODER: Optional[CrossEncoder] = None


def get_cross_encoder() -> CrossEncoder:
    """Return the shared cross-encoder, constructing it on the first call.

    The instance is built from ``CROSS_ENCODER_MODEL_NAME`` and stored in the
    module-level ``_CROSS_ENCODER`` so later calls reuse the same object.
    """
    global _CROSS_ENCODER
    if _CROSS_ENCODER is not None:
        return _CROSS_ENCODER
    _CROSS_ENCODER = CrossEncoder(CROSS_ENCODER_MODEL_NAME)
    return _CROSS_ENCODER
| # ========================= | |
| # Helper Functions | |
| # ========================= | |
def get_openai_client(api_key: str) -> OpenAI:
    """Build an OpenAI client from a user-supplied API key.

    Args:
        api_key: The key as typed by the user; surrounding whitespace is
            removed before it is handed to the client.

    Raises:
        ValueError: If the key is empty or contains only whitespace.
    """
    cleaned_key = api_key.strip() if api_key else ""
    if not cleaned_key:
        raise ValueError("OpenAI API key is missing.")
    return OpenAI(api_key=cleaned_key)
def extract_text_from_file(file_path: str) -> str:
    """Read a document from disk and return its text.

    ``.pdf`` files are parsed page by page with ``PdfReader``; pages that
    yield no text are dropped and the rest are joined with newlines. Every
    other extension (including ``.txt`` and ``.md``) is read as UTF-8 text,
    ignoring undecodable bytes.
    """
    _, suffix = os.path.splitext(file_path)
    if suffix.lower() == ".pdf":
        page_texts = (page.extract_text() for page in PdfReader(file_path).pages)
        return "\n".join(content for content in page_texts if content)

    # Plain-text path: .txt, .md and any unrecognised extension.
    with open(file_path, "r", encoding="utf-8", errors="ignore") as handle:
        return handle.read()
def chunk_text(text: str, chunk_size: int = 800, overlap: int = 200) -> List[str]:
    """Split text into fixed-size character windows that overlap.

    Line endings are normalised to ``\\n`` first. Consecutive chunks start
    ``chunk_size - overlap`` characters apart.

    Args:
        text: The text to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Characters shared between neighbouring chunks; must satisfy
            ``0 <= overlap < chunk_size``.

    Returns:
        The chunks in document order; an empty list for empty text.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range. Without
            this check a non-positive stride would loop forever.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer.")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be non-negative and smaller than chunk_size.")

    text = text.replace("\r\n", "\n").replace("\r", "\n")
    stride = chunk_size - overlap
    total = len(text)
    chunks: List[str] = []
    start = 0
    while start < total:
        end = start + chunk_size
        chunks.append(text[start:end])
        # Once a window reaches the end of the text, any further window would
        # lie entirely inside this one and only add a duplicate tail chunk.
        if end >= total:
            break
        start += stride
    return chunks
def embed_texts(texts: List[str], api_key: str, batch_size: int = 256) -> List[List[float]]:
    """Embed texts with OpenAI's ``text-embedding-3-small`` model.

    The inputs are sent in batches so that a large document does not exceed
    the embeddings endpoint's per-request input limits.

    Args:
        texts: Strings to embed.
        api_key: OpenAI API key used to build the client.
        batch_size: Maximum number of inputs per API request; must be positive.

    Returns:
        One embedding per input, in input order; an empty list when ``texts``
        is empty (no API call is made in that case).

    Raises:
        ValueError: If ``batch_size`` is not positive, or the API key is
            missing (raised by ``get_openai_client``).
    """
    if not texts:
        return []
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer.")

    client = get_openai_client(api_key)
    vectors: List[List[float]] = []
    for offset in range(0, len(texts), batch_size):
        resp = client.embeddings.create(
            model="text-embedding-3-small",
            input=texts[offset:offset + batch_size],
        )
        vectors.extend(item.embedding for item in resp.data)
    return vectors
def add_documents_to_chroma(file_paths: List[str], api_key: str) -> str:
    """Extract, chunk, embed and store each file in the Chroma collection.

    Files whose extracted text is empty or whitespace-only are skipped and
    counted separately, so the status message reports how many files were
    actually indexed rather than how many were submitted.

    Args:
        file_paths: Paths of the documents to index.
        api_key: OpenAI API key used for embedding.

    Returns:
        A human-readable status line for the UI.
    """
    if not file_paths:
        return "⚠️ No files were provided."

    total_chunks = 0
    indexed_files = 0
    skipped_files = 0
    for file_path in file_paths:
        file_name = os.path.basename(file_path)
        raw_text = extract_text_from_file(file_path)
        if not raw_text.strip():
            skipped_files += 1
            continue
        chunks = chunk_text(raw_text)
        embeddings = embed_texts(chunks, api_key)
        collection.add(
            # A random suffix keeps ids unique even when a file name repeats.
            ids=[f"{file_name}-{uuid.uuid4()}" for _ in chunks],
            documents=chunks,
            metadatas=[{"source": file_name} for _ in chunks],
            embeddings=embeddings,
        )
        indexed_files += 1
        total_chunks += len(chunks)

    if indexed_files == 0:
        return "⚠️ No extractable text was found in the provided file(s); nothing was indexed."

    status = (
        f"✅ Indexed {indexed_files} file(s) into Chroma with {total_chunks} chunks. "
        f"Collection now has {collection.count()} vectors."
    )
    if skipped_files:
        status += f" Skipped {skipped_files} file(s) with no extractable text."
    return status
| # ========================= | |
| # Query Expansion | |
| # ========================= | |
def query_expansion(user_query: str, api_key: str) -> List[str]:
    """Ask the chat model for alternative phrasings of a question.

    Args:
        user_query: The question typed by the user.
        api_key: OpenAI API key used to build the client.

    Returns:
        A de-duplicated list whose first element is always the stripped
        original question, followed by the model's rewordings in the order
        they were returned. Empty when the question is empty. If the model
        output cannot be used, the list contains only the original question.
    """
    user_query = (user_query or "").strip()
    if not user_query:
        return []

    client = get_openai_client(api_key)
    system_prompt = (
        "You are an expert in information retrieval systems, particularly skilled in enhancing queries "
        "for document search efficiency."
    )
    user_prompt = f"""
Perform query expansion on the received question by considering alternative phrasings or synonyms commonly used in document retrieval contexts.
If there are multiple ways to phrase the user's question or common synonyms for key terms, provide several reworded versions.
If there are acronyms or words you are not familiar with, do not try to rephrase them.
Return at least 3 versions of the question.
Return ONLY valid JSON with this exact shape:
{{
"expanded": ["q1", "q2", "q3"]
}}
Question:
{user_query}
""".strip()
    completion = client.chat.completions.create(
        model="gpt-4.1-mini",
        temperature=0.2,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )
    raw = completion.choices[0].message.content

    # The content may be None or not the requested shape; treat anything
    # unexpected as "no expansions" instead of raising.
    try:
        data = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        data = None
    candidates = data.get("expanded") if isinstance(data, dict) else None
    if not isinstance(candidates, list):
        candidates = []
    rewordings = [q.strip() for q in candidates if isinstance(q, str) and q.strip()]

    # Original question first, then rewordings; dict.fromkeys removes
    # duplicates while preserving first-seen order.
    return list(dict.fromkeys([user_query, *rewordings]))
def format_expansions_md(expanded: List[str]) -> str:
    """Render expanded queries as a numbered Markdown list under a heading.

    Returns a placeholder hint when there are no queries to show.
    """
    if expanded:
        numbered = "\n".join(
            f"{position}. {text}" for position, text in enumerate(expanded, start=1)
        )
        return "### 🧠 Expanded Queries\n\n" + numbered
    return "*(No expansions yet — type a question and press Enter.)*"
| # ========================= | |
| # LLM Self-Evaluation Helper | |
| # ========================= | |
def evaluate_answer(question: str, context: str, answer: str, api_key: str) -> dict:
    """Have the chat model grade an answer against its retrieved context.

    The model is asked to score five metrics (groundedness, relevance,
    faithfulness, context precision, context recall) from 1 to 5 and to
    reply with a single JSON object.

    Args:
        question: The user query that was answered.
        context: The retrieved context shown to the answering model.
        answer: The answer to evaluate.
        api_key: OpenAI API key used to build the client.

    Returns:
        The parsed JSON object from the model. If the reply is missing, is
        not valid JSON, or is not a JSON object, a fallback dict with the
        same keys is returned in which every score is ``None`` and the raw
        reply is placed in the relevance justification.
    """
    client = get_openai_client(api_key)
    system_prompt = (
        "You are an impartial evaluator for a Retrieval-Augmented Generation (RAG) system. "
        "You will receive: (1) the user query, (2) the retrieved context, and (3) the model's answer. "
        "You must evaluate the answer on five metrics, each scored from 1 (very poor) to 5 (excellent):\n"
        "- Groundedness: Is the answer supported by the retrieved CONTEXT (not outside knowledge)?\n"
        "- Relevance: Does the answer address the USER QUERY directly and appropriately?\n"
        "- Faithfulness: Are the statements logically valid and consistent with the context (no contradictions)?\n"
        "- Context Precision: Does the answer avoid including irrelevant details from the context?\n"
        "- Context Recall: Does the answer capture all IMPORTANT information from the context needed to answer well?\n\n"
        "Return ONLY a single JSON object with this exact structure:\n"
        "{\n"
        ' "query": string,\n'
        ' "response": string,\n'
        ' "groundedness_evaluation": {"score": int, "justification": string},\n'
        ' "relevance_evaluation": {"score": int, "justification": string},\n'
        ' "faithfulness_evaluation": {"score": int, "justification": string},\n'
        ' "context_precision_evaluation": {"score": int, "justification": string},\n'
        ' "context_recall_evaluation": {"score": int, "justification": string}\n'
        "}"
    )
    user_prompt = (
        f"USER QUERY:\n{question}\n\n"
        f"RETRIEVED CONTEXT:\n{context}\n\n"
        f"MODEL ANSWER:\n{answer}"
    )
    completion = client.chat.completions.create(
        model="gpt-4.1-mini",
        temperature=0.0,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )
    raw = completion.choices[0].message.content

    # json.loads raises TypeError (not JSONDecodeError) when raw is None.
    try:
        parsed = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        parsed = None
    # Callers use .get() on the result, so only a JSON object is acceptable.
    if isinstance(parsed, dict):
        return parsed

    return {
        "query": question,
        "response": answer,
        "groundedness_evaluation": {"score": None, "justification": "Failed to parse JSON evaluation."},
        "relevance_evaluation": {"score": None, "justification": raw},
        "faithfulness_evaluation": {"score": None, "justification": ""},
        "context_precision_evaluation": {"score": None, "justification": ""},
        "context_recall_evaluation": {"score": None, "justification": ""},
    }
| # ========================================================= | |
| # REQUIRED: Chroma Retrieval + Cross-Encoder Rerank + Prompt | |
| # ========================================================= | |
def retrieve_from_chroma(query: str, top_k: int, api_key: str) -> List[Dict[str, Any]]:
    """Retrieve up to ``top_k`` passages from Chroma for a single query.

    The query is embedded with ``embed_texts`` and matched against the
    module-level ``collection``.

    Args:
        query: The search text; blank queries return no results.
        top_k: Maximum number of passages to return. It is clamped to the
            number of stored records so the request never asks for more
            results than exist; non-positive values return no results.
        api_key: OpenAI API key used for embedding the query.

    Returns:
        A list of dicts with keys ``id`` (str), ``text`` (str), ``metadata``
        (dict) and ``distance`` (float or None), in the order Chroma
        returned them.
    """
    query = (query or "").strip()
    if not query or top_k <= 0:
        return []

    available = collection.count()
    if available == 0:
        return []

    q_emb = embed_texts([query], api_key)[0]
    results = collection.query(
        query_embeddings=[q_emb],
        n_results=min(top_k, available),
    )

    # Each field is a list with one row per query; a field may be missing or
    # None, so fall back to an empty row instead of indexing into None.
    ids = (results.get("ids") or [[]])[0] or []
    docs = (results.get("documents") or [[]])[0] or []
    metas = (results.get("metadatas") or [[]])[0] or []
    dists = (results.get("distances") or [[]])[0] or []

    out: List[Dict[str, Any]] = []
    for i, (cid, text, meta) in enumerate(zip(ids, docs, metas)):
        out.append({
            "id": cid,
            "text": text,
            "metadata": meta or {},
            "distance": dists[i] if i < len(dists) else None,
        })
    return out
def cross_encoder_rerank(query: str, docs: List[Dict[str, Any]], top_n: int) -> List[Dict[str, Any]]:
    """Score passages against the query with the cross-encoder and keep the best.

    Args:
        query: The question used for scoring; blank queries return no results.
        docs: Passage dicts (as produced by retrieval); each one's ``text``
            field is paired with the query.
        top_n: Number of highest-scoring passages to return.

    Returns:
        Copies of the input dicts with an added float ``score`` field, sorted
        by score in descending order and truncated to ``top_n``. The input
        dicts are not modified.
    """
    cleaned_query = (query or "").strip()
    if not (cleaned_query and docs):
        return []

    encoder = get_cross_encoder()
    raw_scores = encoder.predict(
        [(cleaned_query, passage.get("text", "")) for passage in docs]
    )
    scored = [
        {**passage, "score": float(value)}
        for passage, value in zip(docs, raw_scores)
    ]
    ranked = sorted(
        scored,
        key=lambda item: item.get("score", float("-inf")),
        reverse=True,
    )
    return ranked[:top_n]
def build_prompt(query: str, reranked_docs: List[Dict[str, Any]]) -> Tuple[str, str]:
    """Assemble the context string and the full LLM prompt.

    Each passage becomes a section headed by its source (and page, when the
    metadata carries ``page``, ``page_number`` or ``pageno``); sections are
    separated by a ``---`` rule.

    Returns:
        A ``(context, prompt)`` tuple: the joined sections, and the prompt
        that embeds that context together with the question.
    """
    sections = []
    for passage in reranked_docs:
        meta = passage.get("metadata", {}) or {}
        page = meta.get("page", meta.get("page_number", meta.get("pageno", "")))
        heading = f"Source: {meta.get('source', 'unknown')}"
        # Page 0 is a valid page; only the empty string and None are omitted.
        if not (page == "" or page is None):
            heading = f"{heading} | Page: {page}"
        body = passage.get("text", "")
        sections.append(f"{heading}\n{body}".strip())

    context = "\n\n---\n\n".join(sections).strip()
    prompt = "".join([
        "You are a helpful assistant that answers questions ONLY using the provided document context. ",
        "If the context does not contain the answer, say you do not know.\n\n",
        f"Context from documents:\n\n{context}\n\n",
        f"Question: {query}\n\n",
        "Answer based only on the context above.",
    ])
    return context, prompt
| # ========================= | |
| # Existing Multi-Query RAG (unchanged behavior) | |
| # ========================= | |
| def _merge_docs_by_id(doc_lists: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]: | |
| """ | |
| Merge/dedupe docs (dicts) by Chroma chunk id. Keeps the best (lowest) distance if present. | |
| """ | |
| merged: Dict[str, Dict[str, Any]] = {} | |
| for docs in doc_lists: | |
| for d in docs: | |
| cid = d.get("id") | |
| if not cid: | |
| continue | |
| if cid not in merged: | |
| merged[cid] = d | |
| else: | |
| # keep best distance if both have it | |
| old_dist = merged[cid].get("distance") | |
| new_dist = d.get("distance") | |
| if old_dist is not None and new_dist is not None and new_dist < old_dist: | |
| merged[cid] = d | |
| return list(merged.values()) | |
def query_rag_multi(selected_queries: List[str], api_key: str) -> str:
    """Answer a question from several query phrasings and self-evaluate it.

    Each selected query is embedded and used to retrieve up to five passages;
    the results are merged by chunk id, the five closest are used as context
    for the chat model, and the answer is then graded by ``evaluate_answer``.

    Args:
        selected_queries: Query phrasings chosen by the user; blank and
            non-string entries are ignored.
        api_key: OpenAI API key used for embedding and chat calls.

    Returns:
        Markdown containing the answer followed by the self-evaluation JSON,
        a warning string when there is nothing to search, or the answer with
        an error note when only the self-evaluation step failed.
    """
    selected_queries = [q.strip() for q in (selected_queries or []) if isinstance(q, str) and q.strip()]
    if not selected_queries:
        return "⚠️ Please select at least one expanded query."

    available = collection.count()
    if available == 0:
        return "⚠️ No documents in the database yet. Upload and index some documents first."

    # Embed each selected query, retrieve up to 5 passages per query, merge,
    # then keep the 5 best overall. n_results is clamped so the request never
    # asks for more results than the collection holds.
    q_embs = embed_texts(selected_queries, api_key)
    results = collection.query(
        query_embeddings=q_embs,
        n_results=min(5, available),
    )

    # Convert the per-query result rows into lists of passage dicts.
    all_ids = results.get("ids") or []
    all_docs = results.get("documents") or []
    all_metas = results.get("metadatas") or []
    all_dist = results.get("distances")
    doc_lists: List[List[Dict[str, Any]]] = []
    for qi, docs_i in enumerate(all_docs):
        ids_i = all_ids[qi] if qi < len(all_ids) else []
        metas_i = all_metas[qi] if qi < len(all_metas) else []
        if isinstance(all_dist, list) and qi < len(all_dist):
            dist_i = all_dist[qi]
        else:
            dist_i = [None] * len(docs_i)
        doc_lists.append([
            {"id": cid, "text": doc, "metadata": meta or {}, "distance": dist}
            for cid, doc, meta, dist in zip(ids_i, docs_i, metas_i, dist_i)
        ])

    merged = _merge_docs_by_id(doc_lists)
    if not merged:
        return "I couldn't find any relevant context in the indexed documents."

    def _by_distance(doc: Dict[str, Any]) -> Tuple[bool, float]:
        # Known distances sort ascending; unknown distances go last.
        dist = doc.get("distance")
        return (dist is None, 0.0 if dist is None else dist)

    merged.sort(key=_by_distance)
    top = merged[:5]

    context_parts = []
    for d in top:
        md = d.get("metadata", {}) or {}
        context_parts.append(f"Source: {md.get('source','unknown')}\n{d.get('text','')}")
    context = "\n\n---\n\n".join(context_parts)

    client = get_openai_client(api_key)
    system_prompt = (
        "You are a helpful assistant that answers questions ONLY using the provided document context. "
        "If the context does not contain the answer, say you do not know."
    )
    user_prompt = (
        f"Context from documents:\n\n{context}\n\n"
        "Selected expanded queries:\n- " + "\n- ".join(selected_queries) + "\n\n"
        "Answer based only on the context above."
    )
    completion = client.chat.completions.create(
        model="gpt-4.1-mini",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.1,
    )
    # The message content can be None; treat that as an empty answer.
    response_text = (completion.choices[0].message.content or "").strip()

    # Self-evaluation is best-effort: any failure here still returns the answer.
    try:
        eval_dict = evaluate_answer(
            question=" | ".join(selected_queries),
            context=context,
            answer=response_text,
            api_key=api_key,
        )
        log_record = {
            "query": eval_dict.get("query"),
            "response": eval_dict.get("response"),
            "groundedness_evaluation": eval_dict.get("groundedness_evaluation"),
            "relevance_evaluation": eval_dict.get("relevance_evaluation"),
            "faithfulness_evaluation": eval_dict.get("faithfulness_evaluation"),
            "context_precision_evaluation": eval_dict.get("context_precision_evaluation"),
            "context_recall_evaluation": eval_dict.get("context_recall_evaluation"),
        }
        return (
            f"### 💬 Answer\n\n{response_text}\n\n"
            f"---\n\n"
            f"### 🔍 Self-evaluation (1–5)\n\n"
            f"```json\n{json.dumps(log_record, indent=2)}\n```"
        )
    except Exception as e:
        return (
            f"### 💬 Answer\n\n{response_text}\n\n"
            f"---\n\n"
            f"⚠️ Self-evaluation failed: {e}"
        )
| # ========================= | |
| # Cross-Encode Stage UI Helpers | |
| # ========================= | |
def format_rerank_results_md(query: str, reranked: List[Dict[str, Any]], top_n: int) -> str:
    """Render reranked passages as a Markdown table.

    Args:
        query: The scoring query (accepted for interface symmetry; not
            rendered).
        reranked: Passage dicts, expected to carry ``score``, ``text`` and
            ``metadata``.
        top_n: The rerank cutoff, shown in the heading.

    Returns:
        A Markdown heading plus a table with rank, score, source, page and a
        snippet of at most 160 characters per passage, or a placeholder when
        ``reranked`` is empty. A missing score is shown as ``n/a``.
    """
    if not reranked:
        return "*(No reranked results to display.)*"

    def _cell(value: Any) -> str:
        # A raw pipe or newline inside a cell would break the table row.
        return str(value).replace("\n", " ").replace("|", "\\|")

    lines = [
        f"### 🎯 Cross-Encoder Rerank Results (top {top_n})",
        "",
        "| Rank | Score | Source | Page | Snippet |",
        "|---:|---:|---|---:|---|",
    ]
    for rank, doc in enumerate(reranked, start=1):
        md = doc.get("metadata", {}) or {}
        source = _cell(md.get("source", "unknown"))
        page = md.get("page", md.get("page_number", md.get("pageno", "")))
        page_text = _cell(page) if page is not None else ""

        score = doc.get("score")
        score_text = "n/a" if score is None else f"{float(score):.4f}"

        snippet = (doc.get("text", "") or "").replace("\n", " ").strip()
        if len(snippet) > 160:
            snippet = snippet[:160] + "…"
        # Escape after truncating so an escape sequence is never cut in half.
        snippet = _cell(snippet)

        lines.append(f"| {rank} | {score_text} | {source} | {page_text} | {snippet} |")
    return "\n".join(lines)
| # ========================= | |
| # Gradio Wrappers | |
| # ========================= | |
def gradio_ingest(files, api_key):
    """Gradio handler for the index button.

    Validates the API key and upload, normalises a single upload into a
    list, and returns the indexing status (or an error message) as text.
    """
    if not (api_key and api_key.strip()):
        return "❌ Please enter your OpenAI API key before indexing."
    if not files:
        return "⚠️ Please drop at least one document."

    paths = files if isinstance(files, list) else [files]
    try:
        return add_documents_to_chroma(paths, api_key)
    except Exception as exc:
        # UI boundary: surface the failure as a message instead of raising.
        return f"❌ Error during indexing: {exc}"
def gradio_expand(question: str, api_key: str):
    """Gradio handler that expands the question into selectable queries.

    Returns:
        A ``(checkbox_update, markdown)`` pair: the update sets the checkbox
        choices to the expanded queries with the first one pre-selected, and
        the Markdown previews them. When the API key is missing or the
        expansion call fails, the choices are cleared and the Markdown
        carries the error message.
    """
    if not api_key or not api_key.strip():
        return gr.update(choices=[], value=[]), "❌ Please enter your OpenAI API key first."
    try:
        expanded = query_expansion(question, api_key)
    except Exception as e:
        # UI boundary: report the failure in the preview pane, matching the
        # other handlers, instead of letting it propagate to Gradio.
        return gr.update(choices=[], value=[]), f"❌ Error during query expansion: {e}"
    md = format_expansions_md(expanded)
    return gr.update(choices=expanded, value=expanded[:1]), md
def gradio_run_selected(selected_queries: List[str], api_key: str) -> str:
    """Gradio handler for the search button.

    Checks that an API key and at least one query are present, then runs
    the multi-query pipeline and returns its Markdown (or an error message).
    """
    if not (api_key and api_key.strip()):
        return "❌ Please enter your OpenAI API key before searching."
    if not selected_queries:
        return "⚠️ Please expand a question and select one or more to run."

    try:
        answer_md = query_rag_multi(selected_queries, api_key)
    except Exception as exc:
        # UI boundary: surface the failure as a message instead of raising.
        return f"❌ Error during question answering: {exc}"
    return answer_md
def gradio_cross_encode(original_question: str, selected_queries: List[str], api_key: str) -> Tuple[str, str]:
    """Gradio handler for the cross-encode button.

    Retrieves up to 20 passages per query from Chroma, merges them by chunk
    id, reranks the merged set with the cross-encoder, and keeps the top 5.
    The selected expansions drive retrieval (falling back to the typed
    question); the typed question drives scoring (falling back to the first
    retrieval query).

    Returns:
        A ``(results_markdown, context_markdown)`` pair: the rerank table
        with scores, and the final context string. On a validation problem
        or a runtime failure the first element carries the message and the
        second is empty.
    """
    if not api_key or not api_key.strip():
        return "❌ Please enter your OpenAI API key first.", ""
    if collection.count() == 0:
        return "⚠️ No documents in the database yet. Upload and index some documents first.", ""

    original_question = (original_question or "").strip()
    selected_queries = [q.strip() for q in (selected_queries or []) if isinstance(q, str) and q.strip()]
    if not original_question and not selected_queries:
        return "⚠️ Please type a question and/or select expansions first.", ""

    # Retrieval uses the selected expansions if present, otherwise the
    # original question.
    retrieval_queries = selected_queries if selected_queries else [original_question]
    # Scoring uses the original question if available, otherwise the first
    # retrieval query.
    scoring_query = original_question if original_question else retrieval_queries[0]

    try:
        retrieved_lists = [retrieve_from_chroma(q, top_k=20, api_key=api_key) for q in retrieval_queries]
        merged_docs = _merge_docs_by_id(retrieved_lists)
        if not merged_docs:
            return "I couldn't find any relevant context in the indexed documents.", ""

        reranked = cross_encoder_rerank(scoring_query, merged_docs, top_n=5)
        context, _prompt = build_prompt(scoring_query, reranked)
        md = format_rerank_results_md(scoring_query, reranked, top_n=5)
    except Exception as e:
        # UI boundary: report embedding/retrieval/model failures as a message,
        # consistent with the other Gradio handlers.
        return f"❌ Error during cross-encoder reranking: {e}", ""

    return md, f"### 🧩 Final Context (for LLM)\n\n{context}"
| # ========================= | |
| # Gradio Interface | |
| # ========================= | |
# Gradio UI definition: widgets are declared first, then each button/textbox
# event is wired to its handler function.
# NOTE(review): the nesting of the `with` blocks below was reconstructed from
# a flattened listing — confirm the intended layout (in particular whether the
# answer and rerank panes belong inside the right-hand column).
with gr.Blocks(title="RAG with ChromaDB") as demo:
    gr.Markdown(
        """
# 📚 RAG Q&A with ChromaDB + Gradio (Multi-Select Query Expansion + Cross-Encoder Rerank)
1. Paste your **OpenAI API key** below.
2. **Drag & drop** one or more documents into the upload box.
3. Click **"Index documents"** to store them in a Chroma vector database.
4. Type a question and press **Enter** (or click **Expand**) to generate expanded queries.
5. Select **one or more** expanded queries.
6. Click **Run Search** for the normal pipeline, or **Cross Encode** to view reranked passages + scores + final context.
"""
    )
    with gr.Row():
        # Left column: API key, document upload and indexing status.
        with gr.Column(scale=1):
            api_key_box = gr.Textbox(
                label="OpenAI API Key",
                placeholder="sk-... (this is kept in memory only for this session)",
                type="password",
            )
            file_input = gr.File(
                label="Drop your document(s) here",
                file_count="multiple",
                type="filepath",
            )
            ingest_button = gr.Button("Index documents")
            ingest_status = gr.Markdown("⚙️ Waiting for documents...")
        # Right column: question entry, action buttons, expansions and answer.
        with gr.Column(scale=1):
            question_box = gr.Textbox(
                label="Type a question, then press Enter to expand",
                placeholder="e.g., What are the main findings in the report?",
                lines=3,
            )
            with gr.Row():
                expand_button = gr.Button("Expand")
                run_button = gr.Button("Run Search")
                cross_button = gr.Button("Cross Encode")
            expanded_checks = gr.CheckboxGroup(
                label="Choose one or more expanded queries to run",
                choices=[],
                value=[],
                interactive=True,
            )
            expansions_preview = gr.Markdown("*(No expansions yet — type a question and press Enter.)*")
            answer_box = gr.Markdown("💬 Answer will appear here (with self-evaluation).")
    gr.Markdown("---")
    # Output panes filled by the "Cross Encode" button.
    rerank_results_box = gr.Markdown("*(Cross-encoder rerank results will appear here.)*")
    rerank_context_box = gr.Markdown("*(Final context for the LLM will appear here.)*")
    # Index the uploaded files and show the status message.
    ingest_button.click(
        fn=gradio_ingest,
        inputs=[file_input, api_key_box],
        outputs=[ingest_status],
    )
    # Expand on Enter
    question_box.submit(
        fn=gradio_expand,
        inputs=[question_box, api_key_box],
        outputs=[expanded_checks, expansions_preview],
    )
    # Expand on button click
    expand_button.click(
        fn=gradio_expand,
        inputs=[question_box, api_key_box],
        outputs=[expanded_checks, expansions_preview],
    )
    # Run selected expanded queries (existing pipeline)
    run_button.click(
        fn=gradio_run_selected,
        inputs=[expanded_checks, api_key_box],
        outputs=[answer_box],
    )
    # Cross-encoder rerank (new button + UI outputs)
    cross_button.click(
        fn=gradio_cross_encode,
        inputs=[question_box, expanded_checks, api_key_box],
        outputs=[rerank_results_box, rerank_context_box],
    )
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()