Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import logging | |
| from dotenv import load_dotenv | |
| import tempfile | |
| import pandas as pd | |
| from typing import List, Dict, Any, Tuple | |
| import shutil | |
| import json | |
| from core.ingest import DocumentChunker, HierarchyManager | |
| from core.index import VectorStore | |
| from core.retrieval import RAGManager | |
| from core.eval import RAGEvaluator | |
| from core.utils import generate_id | |
import os as _os

# Optional OpenAI support. _OpenAI is pre-bound to None so that later
# `_OpenAI is not None` checks can never raise NameError when the openai
# package is missing (previously the name was left unbound on ImportError).
_OpenAI = None
_OPENAI_ON = False
try:
    from openai import OpenAI as _OpenAI
    # Enabled only when both the client library and an API key are present.
    _OPENAI_ON = bool(_os.getenv("OPENAI_API_KEY"))
except Exception:
    _OPENAI_ON = False
def _auto_detect_query_filters(query: str) -> Dict[str, Any]:
    """Infer hierarchy filters (level1/2/3, doc_type) from a query.

    Prefers OpenAI; falls back to a heuristic keyword scan over the known
    hierarchies. Returns a possibly-empty dict with any of: level1, level2,
    level3, doc_type.
    """
    # Try OpenAI if available (API key configured and client import succeeded).
    if os.getenv("OPENAI_API_KEY") and _OpenAI is not None:
        try:
            logger.debug("Calling OpenAI for query filter detection.")
            client = _OpenAI()
            prompt = (
                "Given a user query, infer optional filters for a hierarchical RAG system."
                " Return JSON with any of: level1, level2, level3, doc_type."
                " Use concise strings; omit fields you cannot infer. Query: " + query
            )
            resp = client.chat.completions.create(
                model=_os.getenv("OPENAI_MODEL", "gpt-4o-mini"),
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
            )
            import json as _json
            content = resp.choices[0].message.content
            data = _json.loads(content)
            logger.debug(f"OpenAI filters inferred: {data}")
            if isinstance(data, dict) and any(data.get(k) for k in ("level1", "level2", "level3", "doc_type")):
                return data
        except Exception:
            logger.exception("OpenAI query filter detection failed; will try heuristic.")
    # Heuristic fallback: score each level1 branch by keyword hits in the query.
    try:
        hm = HierarchyManager()
        ql = query.lower()
        best: Dict[str, Any] = {"level1": None, "level2": None, "level3": None, "doc_type": None}
        best_score = -1
        for hname, hdef in hm.hierarchies.items():
            for l1 in hdef['levels']['level1']['values']:
                score = 0
                matched_l2 = None
                matched_l3 = None
                if l1.lower() in ql:
                    score += 2
                for l2 in hdef['levels']['level2']['values'].get(l1, []):
                    if l2.lower() in ql:
                        score += 2
                        if matched_l2 is None:
                            matched_l2 = l2
                    for l3 in hdef['levels']['level3']['values'].get(l2, []):
                        if l3.lower() in ql:
                            score += 1
                            if matched_l3 is None:
                                matched_l3 = l3
                if score > best_score:
                    best_score = score
                    # BUG FIX: previously only level1 was recorded here, so a
                    # matching level2/level3 keyword was silently discarded.
                    best.update({"level1": l1, "level2": matched_l2, "level3": matched_l3})
        for dt in ["Policy", "Manual", "FAQ", "Report", "Note", "Guideline"]:
            if dt.lower() in ql:
                best["doc_type"] = dt
                break
        logger.debug(f"Heuristic filters inferred: {best if best_score > 0 or best.get('doc_type') else {}}")
        return best if best_score > 0 or best.get("doc_type") else {}
    except Exception:
        logger.debug("Heuristic detection failed; returning no filters.")
        return {}
# Load environment variables from .env if present, then configure logging
load_dotenv()
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
# NOTE: `logger` is referenced by _auto_detect_query_filters defined above;
# that function only runs after module import completes, so the later binding
# here is safe.
logger = logging.getLogger("rag_app")
if os.getenv("OPENAI_API_KEY"):
    logger.info("OpenAI API key detected. OpenAI-powered auto-detection is ENABLED.")
    if os.getenv("OPENAI_MODEL"):
        logger.info(f"OpenAI model: {os.getenv('OPENAI_MODEL')}")
else:
    logger.info("OpenAI API key not set. Falling back to heuristic auto-detection.")
# Global variables
# Module-level mutable state shared by the Gradio handlers; populated by
# initialize_system() / reset_index().
rag_manager = None
evaluator = None
current_collection = "documents"
persist_directory = None
def _writable_dir(path: str) -> bool:
    """Create *path* if needed and verify a file can be written inside it."""
    try:
        os.makedirs(path, exist_ok=True, mode=0o755)
        probe = os.path.join(path, ".test_write")
        with open(probe, 'w') as f:
            f.write("test")
        os.remove(probe)
        return True
    except (PermissionError, OSError):
        return False

def initialize_system():
    """Initialize the RAG system.

    Picks the first writable persistence directory — the HF Spaces /data
    mount when present, otherwise ./chroma_data — and (re)builds the global
    RAGManager / RAGEvaluator pair.

    Returns:
        Human-readable status string naming the chosen persist directory.
    """
    global rag_manager, evaluator, persist_directory
    # Prefer /data/chroma (HF Spaces persistent storage) when it exists.
    candidates = []
    if os.path.exists("/data/chroma"):
        candidates.append("/data/chroma")
    candidates.append("./chroma_data")
    persist_dir = candidates[-1]
    for cand in candidates:
        if _writable_dir(cand):
            persist_dir = cand
            break
    else:
        # Last resort: mirror the original behavior and attempt to create
        # ./chroma_data, letting any error propagate to the caller.
        os.makedirs(persist_dir, exist_ok=True, mode=0o755)
    persist_directory = persist_dir
    rag_manager = RAGManager(persist_directory=persist_dir)
    evaluator = RAGEvaluator(rag_manager)
    return f"System initialized successfully! Using persist directory: {persist_dir}"
def reset_index() -> str:
    """Wipe the Chroma persistence directory and rebuild the vector store."""
    global rag_manager, evaluator, persist_directory
    try:
        target = persist_directory
        if not target:
            # No directory chosen yet: apply the same preference as startup.
            target = "/data/chroma" if os.path.exists("/data/chroma") else "./chroma_data"
        if os.path.exists(target):
            shutil.rmtree(target, ignore_errors=True)
        os.makedirs(target, exist_ok=True, mode=0o755)
        persist_directory = target
        rag_manager = RAGManager(persist_directory=target)
        evaluator = RAGEvaluator(rag_manager)
        return f"Index reset complete. Using fresh directory: {target}"
    except Exception as ex:
        return f"Failed to reset index: {ex}"
def upload_documents(files: List[str], hierarchy: str, doc_type: str, language: str, progress: Any = None) -> Tuple[str, List[Dict[str, Any]], Dict[str, Any], List[Dict[str, Any]]]:
    """Chunk and index the given files.

    Args:
        files: Paths of documents to process.
        hierarchy, doc_type, language: Explicit values, or None to auto-detect.
        progress: Optional callable(idx, total=..., desc=...) for UI progress.

    Returns:
        (status_text, per_file_summaries, collection_stats, chunk_rows) where
        per_file_summaries is [{Filename, Chunks, Language, Doc Type, Hierarchy}].
    """
    global rag_manager, persist_directory
    # BUG FIX: this branch used to return a bare string, which broke any
    # caller unpacking the documented 4-tuple.
    if not files:
        return "No files provided!", [], {}, []
    # Ensure system is initialized
    if not rag_manager:
        initialize_system()
    from collections import Counter  # hoisted out of the per-file loop
    chunker = DocumentChunker()
    all_chunks = []
    per_file_summaries: List[Dict[str, Any]] = []
    chunk_rows: List[Dict[str, Any]] = []
    processed_count = 0
    errors: List[str] = []
    total = len(files)
    for idx, file_path in enumerate(files, start=1):
        if progress:
            try:
                progress(idx, total=total, desc=f"Processing {idx}/{total}: {os.path.basename(file_path)}")
            except Exception:
                pass  # progress reporting is best-effort; never abort on it
        try:
            chunks = chunker.chunk_document(file_path, hierarchy, doc_type, language)
            if not chunks:
                errors.append(f"Warning: {os.path.basename(file_path)} produced no chunks")
                continue
            # Aggregate per-file metadata from chunks (majority vote).
            langs = Counter(c.metadata.get('lang') for c in chunks if c.metadata.get('lang'))
            docts = Counter(c.metadata.get('doc_type') for c in chunks if c.metadata.get('doc_type'))
            hier_names = Counter(c.metadata.get('hierarchy') for c in chunks if c.metadata.get('hierarchy'))
            top_lang = langs.most_common(1)[0][0] if langs else None
            top_doct = docts.most_common(1)[0][0] if docts else None
            top_hier = hier_names.most_common(1)[0][0] if hier_names else None
            per_file_summaries.append({
                'Filename': os.path.basename(file_path),
                'Chunks': len(chunks),
                'Language': top_lang,
                'Doc Type': top_doct,
                'Hierarchy': top_hier
            })
            # Per-chunk preview rows for the UI table.
            for c in chunks:
                md = c.metadata or {}
                chunk_rows.append({
                    'Filename': os.path.basename(file_path),
                    'Level1': md.get('level1'),
                    'Level2': md.get('level2'),
                    'Level3': md.get('level3'),
                    'Doc Type': md.get('doc_type'),
                    'Language': md.get('lang'),
                    'Preview': (c.content[:160] + '...') if c.content else ''
                })
            # Log indexing summary per file (without AI percentage)
            try:
                logger.info("Indexed %s: chunks=%d, lang=%s, doc_type=%s, hierarchy=%s",
                            os.path.basename(file_path), len(chunks),
                            top_lang, top_doct, top_hier)
            except Exception:
                pass  # logging must never abort indexing
            all_chunks.extend(chunks)
            processed_count += 1
        except Exception as e:
            # Record the failure and continue with the remaining files.
            errors.append(f"{os.path.basename(file_path)}: {str(e)}")
    # Index if any chunk present
    vector_store = rag_manager.vector_store
    stats: Dict[str, Any] = {"document_count": 0, "collection_name": current_collection}
    if all_chunks:
        vector_store.add_documents(current_collection, all_chunks)
        stats = vector_store.get_collection_stats(current_collection)
    # Build result message
    status_lines = [
        f"Processed {processed_count}/{total} files",
        f"Indexed chunks: {len(all_chunks)}"
    ]
    if errors:
        status_lines.append("\nErrors/Warnings:\n" + "\n".join(f"- {e}" for e in errors))
    return "\n".join(status_lines), per_file_summaries, stats, chunk_rows
def build_rag_index(files: List[Any], hierarchy: str, doc_type: str, language: str) -> Tuple[str, pd.DataFrame, pd.DataFrame]:
    """Resolve uploaded file paths, process them, and return UI tables.

    Returns:
        (status_text, per_file_df, chunks_df). On error the dataframes are
        None, so the three wired Gradio outputs always receive a value.
    """
    # BUG FIX: error branches used to return 2-tuples although this handler
    # is wired to three outputs; every return now has arity 3.
    if not files:
        return "No files provided!", None, None
    file_paths = []
    for file in files:
        # Gradio may hand us a plain path string or an object whose .name
        # already holds the full temp path (e.g. /tmp/gradio/.../file.txt).
        if isinstance(file, str):
            file_path = file
        elif hasattr(file, 'name') and file.name:
            file_path = file.name
        else:
            # Fallback for edge cases
            return "Error: Unable to get file path from uploaded file", None, None
        # Normalize the path to handle any double slashes
        file_path = os.path.normpath(file_path)
        # Ensure the file exists
        if not os.path.exists(file_path):
            return f"Error: File not found at {file_path}", None, None
        file_paths.append(file_path)
    # Normalize "Auto" to None for auto-detection downstream
    norm_hierarchy = None if not hierarchy or str(hierarchy).lower() == 'auto' else hierarchy
    norm_doc_type = None if not doc_type or str(doc_type).lower() == 'auto' else doc_type
    norm_language = None if not language or str(language).lower() == 'auto' else language
    # Process documents (progress provided by gradio)
    status_text, file_summaries, stats, chunk_rows = upload_documents(file_paths, norm_hierarchy, norm_doc_type, norm_language, progress=None)
    # Build per-file dataframe (no totals row)
    per_file_df = pd.DataFrame(file_summaries) if file_summaries else pd.DataFrame(columns=['Filename', 'Chunks', 'Language', 'Doc Type', 'Hierarchy'])
    chunks_df = pd.DataFrame(chunk_rows) if chunk_rows else pd.DataFrame(columns=['Filename', 'Level1', 'Level2', 'Level3', 'Doc Type', 'Language', 'Preview'])
    return status_text, per_file_df, chunks_df
def search_documents(query: str, k: int, level1: str, level2: str, level3: str, doc_type: str) -> Tuple[str, str, str, pd.DataFrame]:
    """Run both retrieval pipelines for *query* and package results for the UI."""
    if not rag_manager:
        return "System not initialized!", "", "", None
    # Empty-string filters mean "unset".
    level1 = level1 or None
    level2 = level2 or None
    level3 = level3 or None
    doc_type = doc_type or None
    # When no filter was supplied, try to infer some from the query itself.
    if not any((level1, level2, level3, doc_type)):
        inferred = _auto_detect_query_filters(query)
        level1 = level1 or inferred.get('level1')
        level2 = level2 or inferred.get('level2')
        level3 = level3 or inferred.get('level3')
        doc_type = doc_type or inferred.get('doc_type')
    base_result, hier_result = rag_manager.compare_retrieval(query, k, level1, level2, level3, doc_type)

    def _as_rows(pipeline_name, sources):
        # One display row per retrieved source, ranked from 1.
        return [{
            'Pipeline': pipeline_name,
            'Rank': rank,
            'Content': src['content'][:100] + '...',
            'Domain': src['metadata'].get('level1', 'N/A'),
            'Section': src['metadata'].get('level2', 'N/A'),
            'Topic': src['metadata'].get('level3', 'N/A'),
            'Score': f"{src['score']:.3f}"
        } for rank, src in enumerate(sources, start=1)]

    results_df = pd.DataFrame(_as_rows('Base-RAG', base_result.sources) + _as_rows('Hier-RAG', hier_result.sources))
    comparison_text = f"""
## Retrieval Comparison
### Base-RAG
- Latency: {base_result.latency:.3f}s
- Retrieved: {len(base_result.sources)} documents
### Hier-RAG
- Latency: {hier_result.latency:.3f}s
- Retrieved: {len(hier_result.sources)} documents
- Filters: level1={level1 or 'None'}, level2={level2 or 'None'}, level3={level3 or 'None'}, doc_type={doc_type or 'None'}
"""
    return base_result.content, hier_result.content, comparison_text, results_df
def search_documents_auto(query: str) -> Tuple[str, str, str, pd.DataFrame]:
    """Run a query-only search.

    k comes from the DEFAULT_SEARCH_K env var (default 5); hierarchy filters
    are inferred from the query by the downstream search.
    """
    auto_k = int(os.getenv("DEFAULT_SEARCH_K", 5))
    return search_documents(query, auto_k, None, None, None, None)
def search_documents_unified(query: str, manual: bool, k: int = None,
                             level1: str = None, level2: str = None,
                             level3: str = None, doc_type: str = None) -> Tuple[str, str, str, pd.DataFrame]:
    """Single search entry point: manual controls when requested, else auto."""
    if not manual:
        return search_documents_auto(query)

    def _norm(value):
        # Treat None, empty values, and the "Auto" sentinel as unset.
        if value is None or value == '':
            return None
        if isinstance(value, str) and value.strip().lower() == 'auto':
            return None
        return value

    chosen_k = k or int(os.getenv("DEFAULT_SEARCH_K", 5))
    return search_documents(query, chosen_k, _norm(level1), _norm(level2), _norm(level3), _norm(doc_type))
def _toggle_manual_controls(manual: bool):
    """Sync the manual-search widgets with the Manual checkbox.

    Returns updates for (accordion, k slider, level1, level2, level3,
    doc_type): the accordion opens, the slider unlocks, and each dropdown is
    reset to "Auto" with interactivity matching the checkbox.
    """
    dropdown_reset = gr.update(value="Auto", interactive=manual)
    updates = [gr.update(open=manual), gr.update(interactive=manual)]
    updates.extend([dropdown_reset] * 4)  # level1, level2, level3, doc_type
    return tuple(updates)
def _llm_answer(user_message: str, contexts: List[Dict[str, Any]]) -> str:
    """Produce a conversational answer grounded in *contexts*.

    Prefers an OpenAI chat completion when the key and client are available;
    otherwise synthesizes a short bulleted summary from the top snippets.
    """
    # Compact, numbered context string for the prompt.
    numbered = []
    for idx, ctx in enumerate(contexts, 1):
        origin = ctx.get('metadata', {}).get('source_name', 'unknown')
        body = (ctx.get('content', '') or '')[:400]
        numbered.append(f"[{idx}] ({origin}) {body}")
    ctx_text = "\n\n".join(numbered)
    # Prefer OpenAI when fully configured.
    if os.getenv("OPENAI_API_KEY") and '_OpenAI' in globals() and _OpenAI is not None:
        try:
            client = _OpenAI()
            system_prompt = (
                "You are a helpful, professional assistant. Answer in a warm, natural, and concise tone. "
                "ALWAYS ground the answer ONLY in the provided contexts. If information is missing, say so. "
                "Style: Start with a clear 1-2 sentence answer. Then, if helpful, add 2-5 short bullet points with key facts. "
                "Avoid hedging, avoid citations inline, avoid repeating the question."
            )
            content = (
                f"User question:\n{user_message}\n\n"
                f"Contexts (each begins with [n]):\n{ctx_text}\n\n"
                "Write a natural, human-like response as specified."
            )
            resp = client.chat.completions.create(
                model=os.getenv("OPENAI_MODEL", "gpt-4o-mini"),
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": content},
                ],
                temperature=0.2,
            )
            return resp.choices[0].message.content.strip()
        except Exception:
            pass  # fall through to the heuristic synthesis below
    # Fallback: naive sentence-split synthesis from the top three snippets.
    if contexts:
        texts = [(ctx.get('content') or '').strip() for ctx in contexts[:3]]
        merged = " ".join(t for t in texts if t)
        sentences = [s.strip() for s in merged.replace('\n', ' ').split('.') if s.strip()]
        bullets = sentences[:4]
        if bullets:
            bullet_lines = "\n".join(f"- {b}." for b in bullets)
            lead = "Here’s what the documents say:"
            return f"{lead}\n\n{bullet_lines}"
    return "I don't have enough information to answer that yet."
def chat_with_rag(message: str, history: List[Dict[str, str]], pipeline: str, k: int) -> Tuple[str, List[Dict[str, str]]]:
    """Answer *message* with the selected pipeline and append to chat history.

    hier_rag additionally auto-infers hierarchy filters from the message; the
    reply ends with the list of source documents (and any active filters).
    """
    if not rag_manager:
        return "System not initialized! Please build the RAG index first.", history
    filters_note = ""
    if pipeline == "hier_rag":
        detected = _auto_detect_query_filters(message)
        lv1, lv2, lv3, dtype = (detected.get(key) for key in ('level1', 'level2', 'level3', 'doc_type'))
        result = rag_manager.hier_rag.retrieve(message, k, lv1, lv2, lv3, dtype)
        # Surface filters in the reply only when something was detected.
        active = [f"{label}={val}" for label, val in
                  (("level1", lv1), ("level2", lv2), ("level3", lv3), ("doc_type", dtype)) if val]
        if active:
            filters_note = " (filters: " + ", ".join(active) + ")"
    else:
        result = rag_manager.base_rag.retrieve(message, k)
    # Grounded answer plus the sources it drew from.
    answer = _llm_answer(message, result.sources)
    bullet_sources = [f"• {s.get('metadata', {}).get('source_name', 'unknown')}" for s in result.sources[:k]]
    response = (
        f"{answer}\n\n"
        f"─────────────────────────────────────\n"
        f"📎 Sources ({'Hier-RAG' if pipeline=='hier_rag' else 'Base-RAG'}{filters_note}):\n" + ("\n".join(bullet_sources) or "(none)")
    )
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": response})
    return "", history
def run_evaluation(queries_json: str, output_filename: str) -> Tuple[str, pd.DataFrame, pd.DataFrame]:
    """Run the batch evaluation and build the comparison plot table.

    Args:
        queries_json: JSON array of query objects (query, ground_truth, ...).
        output_filename: CSV path handed to the evaluator for raw results.

    Returns:
        (status_text, raw_results_df, plot_data); plot_data holds hit@k and
        MRR for both pipelines keyed by k.
    """
    if not evaluator:
        return "System not initialized!", None, None
    try:
        queries = json.loads(queries_json)
    except json.JSONDecodeError as e:
        return f"Invalid JSON format: {str(e)}", None, None
    df, results = evaluator.batch_evaluate(queries, output_filename)
    # Mean metrics per (pipeline, k).
    summary = df.groupby(['pipeline', 'k']).agg({
        'hit_at_k': 'mean',
        'mrr': 'mean',
        'semantic_similarity': 'mean',
        'latency': 'mean',
        'retrieved_count': 'mean'
    }).reset_index()
    # BUG FIX: the per-pipeline slices keep their original (disjoint) row
    # indices; constructing a DataFrame from such Series aligns on index and
    # yields all-NaN columns. Reset both indices so rows pair up by k.
    base = summary[summary['pipeline'] == 'base_rag'].reset_index(drop=True)
    hier = summary[summary['pipeline'] == 'hier_rag'].reset_index(drop=True)
    plot_data = pd.DataFrame({
        'k': base['k'],
        'base_rag_hit@k': base['hit_at_k'],
        'hier_rag_hit@k': hier['hit_at_k'],
        'base_rag_mrr': base['mrr'],
        'hier_rag_mrr': hier['mrr']
    })
    return f"Evaluation completed! Processed {len(queries)} queries. Results saved to {output_filename}", df, plot_data
# Diagnostics: simple OpenAI connectivity test
## (removed) test_openai_connectivity helper
# Initialize system
# Eagerly build the RAG manager/evaluator at import time so the UI handlers
# have a working system before any button is pressed.
initialize_system()
# Create Gradio interface
# Minimal CSS to keep layout stable when vertical scrollbar appears and improve mobile spacing
APP_CSS = """
html, body { scrollbar-gutter: stable both-edges; }
body { overflow-y: scroll; }
* { box-sizing: border-box; }
@media (max-width: 768px) {
.gradio-container { padding-left: 8px; padding-right: 8px; }
}
"""
# Build the Gradio UI. Components are declared tab-by-tab; event handlers
# are wired at the bottom, once every referenced component exists.
with gr.Blocks(title="RAG Evaluation System", css=APP_CSS) as demo:
    gr.Markdown("# RAG Evaluation System: Hierarchical vs Standard RAG")
    # --- Upload tab: file ingestion and index building ---
    with gr.Tab("Upload Documents"):
        gr.Markdown("## Upload and Process Documents")
        with gr.Row():
            with gr.Column():
                file_upload = gr.File(
                    label="Upload PDF/TXT Files",
                    file_count="multiple",
                    file_types=[".pdf", ".txt"]
                )
                # "Auto" sentinels are normalized to None in build_rag_index.
                hierarchy_dropdown = gr.Dropdown(
                    choices=["Auto", "hospital", "bank", "fluid_simulation"],
                    label="Select Hierarchy",
                    value="Auto"
                )
                doc_type_dropdown = gr.Dropdown(
                    choices=["Auto", "Policy", "Manual", "FAQ", "Report", "Note", "Guideline"],
                    label="Document Type",
                    value="Auto"
                )
                language_dropdown = gr.Dropdown(
                    choices=["Auto", "en", "ja"],
                    label="Language",
                    value="Auto"
                )
                build_btn = gr.Button("Build RAG Index")
            with gr.Column():
                build_output = gr.Textbox(label="Build Status", lines=4)
                stats_table = gr.DataFrame(label="File Summary")
                chunks_table = gr.DataFrame(label="Indexed Chunks (preview)")
                reset_btn = gr.Button("Reset Index (Clear chroma_data)", variant="secondary")
    # --- Search tab: side-by-side pipeline comparison ---
    with gr.Tab("Search"):
        gr.Markdown("## Compare Retrieval Pipelines")
        with gr.Row():
            with gr.Column():
                manual_toggle = gr.Checkbox(label="Manual controls", value=False)
                search_query = gr.Textbox(
                    label="Search Query",
                    placeholder="Enter your query...",
                    lines=2
                )
                # Single Search button
                search_btn = gr.Button("Search", variant="primary")
                # Manual widgets start disabled; _toggle_manual_controls
                # unlocks them when the checkbox is ticked.
                with gr.Accordion("Manual (optional)", open=False) as manual_section:
                    k_slider = gr.Slider(minimum=1, maximum=20, value=int(os.getenv("DEFAULT_SEARCH_K", 5)), step=1, label="Number of results (k)", interactive=False)
                    with gr.Row():
                        level1_filter = gr.Dropdown(label="Domain (Level1)", choices=["Auto"], value="Auto", allow_custom_value=True, interactive=False)
                        level2_filter = gr.Dropdown(label="Section (Level2)", choices=["Auto"], value="Auto", allow_custom_value=True, interactive=False)
                    with gr.Row():
                        level3_filter = gr.Dropdown(label="Topic (Level3)", choices=["Auto"], value="Auto", allow_custom_value=True, interactive=False)
                        doc_type_filter = gr.Dropdown(label="Document Type", choices=["Auto", "Policy", "Manual", "FAQ", "Report", "Note", "Guideline"], value="Auto", allow_custom_value=True, interactive=False)
            with gr.Column():
                base_results = gr.Textbox(
                    label="Base-RAG Results",
                    lines=8,
                    max_lines=12
                )
                hier_results = gr.Textbox(
                    label="Hier-RAG Results",
                    lines=8,
                    max_lines=12
                )
                comparison_text = gr.Markdown()
                results_table = gr.DataFrame(label="Detailed Results")
    # --- Chat tab: conversational interface over one pipeline ---
    with gr.Tab("Chat"):
        gr.Markdown("## Chat with RAG System")
        with gr.Row():
            with gr.Column(scale=1):
                pipeline_radio = gr.Radio(
                    choices=["base_rag", "hier_rag"],
                    label="RAG Pipeline",
                    value="hier_rag"
                )
                chat_k_slider = gr.Slider(
                    minimum=1, maximum=10, value=3, step=1,
                    label="Number of results"
                )
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="RAG Chat", type="messages")
                chat_input = gr.Textbox(
                    label="Message",
                    placeholder="Ask a question...",
                    lines=2
                )
                chat_btn = gr.Button("Send")
    # --- Evaluation tab: batch metrics over a JSON query set ---
    with gr.Tab("Evaluation"):
        gr.Markdown("## Quantitative Evaluation")
        with gr.Row():
            with gr.Column():
                eval_queries = gr.Textbox(
                    label="Evaluation Queries (JSON)",
                    lines=10,
                    placeholder='''[
{
"query": "question 1",
"ground_truth": ["expected answer 1", "expected answer 2"],
"k_values": [1, 3, 5],
"level1": "Clinical",
"level2": "Emergency",
"level3": "Triage",
"doc_type": "Report"
}
]''',
                    value='''[
{
"query": "What are the emergency procedures?",
"ground_truth": ["Emergency protocols for triage", "Patient assessment guidelines"],
"k_values": [1, 3, 5]
}
]'''
                )
                eval_output_name = gr.Textbox(
                    label="Output Filename",
                    value="evaluation_results.csv"
                )
                eval_btn = gr.Button("Run Evaluation", variant="primary")
                # Diagnostics removed
            with gr.Column():
                eval_output = gr.Textbox(label="Evaluation Status", lines=3)
                eval_results_table = gr.DataFrame(label="Evaluation Results")
                eval_plot = gr.LinePlot(
                    label="Performance Comparison",
                    x="k",
                    y=["base_rag_hit@k", "hier_rag_hit@k"],
                    title="Hit@k Comparison",
                    width=600,
                    height=400
                )
    # --- Settings tab: embedding provider configuration ---
    with gr.Tab("Settings"):
        gr.Markdown("## Settings")
        gr.Markdown("Configure embedding models and system preferences.")
        with gr.Accordion("Embedding Configuration", open=True):
            gr.Markdown("**Select the embedding provider and model.** Switching providers requires re-indexing your documents.")
            with gr.Row():
                with gr.Column():
                    emb_provider = gr.Radio(
                        choices=["SentenceTransformers", "OpenAI"],
                        value="SentenceTransformers",
                        label="Embeddings Provider",
                        info="Choose between local SentenceTransformers models or OpenAI embeddings (requires API key)"
                    )
            with gr.Row():
                apply_embed_btn = gr.Button("Apply Embedding Settings", variant="primary")
            with gr.Row():
                with gr.Column():
                    st_model_in = gr.Textbox(
                        label="SentenceTransformers Model",
                        value=os.getenv("ST_EMBED_MODEL", "all-MiniLM-L6-v2"),
                        interactive=False,
                        info="Local embedding model (384 dimensions)"
                    )
                with gr.Column():
                    oai_model_in = gr.Textbox(
                        label="OpenAI Embedding Model",
                        value=os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small"),
                        interactive=False,
                        info="OpenAI embedding model (1536 dimensions for small, 3072 for large)"
                    )
            embed_status = gr.Textbox(
                label="Status",
                lines=3,
                interactive=False,
                placeholder="Embedding configuration status will appear here..."
            )
    # Define handler before wiring it
    def _apply_embeddings(provider, st_model, oai_model):
        """Apply the chosen embedding provider/model to the vector store."""
        try:
            use_oai = (provider == "OpenAI")
            rag_manager.vector_store.configure_embeddings(use_oai, openai_model=oai_model, st_model_name=st_model)
            status_msg = f"✅ Embeddings successfully configured!\n\n"
            status_msg += f"Provider: {provider}\n"
            if use_oai:
                status_msg += f"Model: {oai_model} (OpenAI)\n"
                status_msg += f"Dimensions: {3072 if 'large' in oai_model.lower() else 1536}\n"
            else:
                status_msg += f"Model: {st_model} (SentenceTransformers)\n"
                status_msg += f"Dimensions: ~384\n"
            status_msg += f"\n⚠️ Note: If switching providers, reset and rebuild your index in the Upload tab."
            return status_msg
        except Exception as ex:
            return f"❌ Failed to set embeddings: {ex}\n\nPlease check your configuration and try again."
    apply_embed_btn.click(
        fn=_apply_embeddings,
        inputs=[emb_provider, st_model_in, oai_model_in],
        outputs=embed_status
    )
    # Event handlers
    build_btn.click(
        fn=build_rag_index,
        inputs=[file_upload, hierarchy_dropdown, doc_type_dropdown, language_dropdown],
        outputs=[build_output, stats_table, chunks_table],
        api_name="build_rag"
    )
    reset_btn.click(
        fn=reset_index,
        inputs=None,
        outputs=build_output
    )
    # One button: uses manual if checked; otherwise auto
    search_btn.click(
        fn=search_documents_unified,
        inputs=[search_query, manual_toggle, k_slider, level1_filter, level2_filter, level3_filter, doc_type_filter],
        outputs=[base_results, hier_results, comparison_text, results_table],
        api_name="search"
    )
    # Toggle manual controls interactive and accordion state
    manual_toggle.change(
        fn=_toggle_manual_controls,
        inputs=[manual_toggle],
        outputs=[manual_section, k_slider, level1_filter, level2_filter, level3_filter, doc_type_filter]
    )
    # chat_with_rag clears the input by returning ""; the .then() clears it
    # again client-side without queueing.
    chat_btn.click(
        fn=chat_with_rag,
        inputs=[chat_input, chatbot, pipeline_radio, chat_k_slider],
        outputs=[chat_input, chatbot],
        api_name="chat"
    ).then(
        lambda: None,
        None,
        chat_input,
        queue=False
    )
    eval_btn.click(
        fn=run_evaluation,
        inputs=[eval_queries, eval_output_name],
        outputs=[eval_output, eval_results_table, eval_plot],
        api_name="evaluate"
    )
    # Diagnostics trigger removed
# MCP Server Implementation
import asyncio
import sys
from typing import Any, List, Optional
try:
    from mcp.server import Server
    from mcp.server.models import InitializationOptions
    from mcp.types import Tool, TextContent
    MCP_AVAILABLE = True
except ImportError:
    # Fallback for when MCP is not installed.
    # BUG FIX: bind every imported name (InitializationOptions was previously
    # left unbound) so later references cannot raise NameError.
    MCP_AVAILABLE = False
    Server = None
    InitializationOptions = None
    Tool = None
    TextContent = None
class RAGMCPServer:
    """Expose the RAG system over the Model Context Protocol.

    Offers two tools: ``search_documents`` (single-pipeline retrieval) and
    ``evaluate_retrieval`` (batch metric evaluation).
    """

    def __init__(self):
        # Mirror the web app's persistence preference: HF Spaces /data mount
        # first, local ./chroma_data otherwise.
        storage = "/data/chroma" if os.path.exists("/data/chroma") else "./chroma_data"
        self.rag_manager = RAGManager(persist_directory=storage)
        self.evaluator = RAGEvaluator(self.rag_manager)

    async def list_tools(self) -> List[Tool]:
        """Describe the tools this server offers."""
        search_schema = {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Search query"},
                "k": {"type": "integer", "description": "Number of results", "default": 5},
                "pipeline": {"type": "string", "enum": ["base_rag", "hier_rag"], "default": "base_rag"},
                "level1": {"type": "string", "description": "Level1 filter (domain)"},
                "level2": {"type": "string", "description": "Level2 filter (section)"},
                "level3": {"type": "string", "description": "Level3 filter (topic)"},
                "doc_type": {"type": "string", "description": "Document type filter"}
            },
            "required": ["query"]
        }
        eval_schema = {
            "type": "object",
            "properties": {
                "queries": {
                    "type": "array",
                    "description": "List of query objects with query, ground_truth, k_values, and optional filters",
                    "items": {"type": "object"}
                },
                "output_file": {"type": "string", "description": "Output filename for results"}
            },
            "required": ["queries"]
        }
        return [
            Tool(
                name="search_documents",
                description="Search documents using RAG system (Base-RAG or Hier-RAG)",
                inputSchema=search_schema
            ),
            Tool(
                name="evaluate_retrieval",
                description="Evaluate RAG performance with batch queries",
                inputSchema=eval_schema
            )
        ]

    async def call_tool(self, name: str, arguments: dict) -> List[TextContent]:
        """Dispatch a tool invocation; returns its JSON result as text content."""
        if name == "search_documents":
            query = arguments.get("query")
            k = arguments.get("k", 5)
            pipeline = arguments.get("pipeline", "base_rag")
            if pipeline == "base_rag":
                result = self.rag_manager.base_rag.retrieve(query, k)
            else:
                # Hierarchical retrieval honors the optional level/type filters.
                result = self.rag_manager.hier_rag.retrieve(
                    query, k,
                    arguments.get("level1"), arguments.get("level2"),
                    arguments.get("level3"), arguments.get("doc_type"))
            payload = {
                "content": result.content,
                "sources": [
                    {
                        "content": src['content'][:200],
                        "metadata": src['metadata'],
                        "score": src['score']
                    } for src in result.sources
                ],
                "latency": result.latency,
                "strategy": pipeline
            }
            return [TextContent(type="text", text=json.dumps(payload, indent=2))]
        if name == "evaluate_retrieval":
            queries = arguments.get("queries", [])
            output_file = arguments.get("output_file")
            df, results = self.evaluator.batch_evaluate(queries, output_file)
            summary = df.groupby('pipeline').agg({
                'hit_at_k': 'mean',
                'mrr': 'mean',
                'semantic_similarity': 'mean',
                'latency': 'mean'
            }).reset_index()
            payload = {
                "summary": summary.to_dict('records'),
                "total_queries": len(queries),
                "output_file": output_file
            }
            return [TextContent(type="text", text=json.dumps(payload, indent=2))]
        raise ValueError(f"Unknown tool: {name}")
# Export for Gradio Client
if __name__ == "__main__":
    # CLI entry point: serve with plain Gradio. HF Spaces imports `demo`
    # directly and never reaches this branch.
    bind_host = os.getenv("HOST", "0.0.0.0")
    fallback_port = os.getenv("GRADIO_SERVER_PORT", 7860)
    bind_port = int(os.getenv("PORT", fallback_port))
    # SSR and API schema are disabled to avoid response-length errors on Spaces.
    demo.launch(server_name=bind_host, server_port=bind_port, share=False, ssl_verify=False, ssr_mode=False)