Spaces:
Sleeping
Sleeping
| """Main Gradio application for RAG evaluation.""" | |
| import gradio as gr | |
| import os | |
| from pathlib import Path | |
| from typing import List, Dict, Any, Optional, Tuple | |
| import json | |
| import pandas as pd | |
| from datetime import datetime | |
| from core.ingest import DocumentProcessor | |
| from core.index import IndexManager | |
| from core.retrieval import RAGComparator | |
| from core.eval import RAGEvaluator, BenchmarkDataset | |
| from core.utils import load_hierarchy, save_json | |
| from dotenv import load_dotenv | |
import logging

# Log records go to both a local file (app.log) and the console stream.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.FileHandler('app.log'), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# Pull settings from a .env file (if any) into the process environment.
load_dotenv()

# Startup notice only: the key is checked again when the system is initialized.
if os.getenv("OPENAI_API_KEY"):
    print("β OpenAI API key loaded successfully")
else:
    print("β οΈ WARNING: OPENAI_API_KEY not found in environment!")

# Module-level state shared by the callbacks below.
# The three component handles stay None until the corresponding setup step runs.
index_manager = rag_comparator = evaluator = None
current_hierarchy = "hospital"
current_collection = "rag_documents"
# System initialization
def initialize_system():
    """Initialize the RAG system.

    Reads ``VECTOR_DB_PATH``, ``EMBEDDING_MODEL``, ``LLM_MODEL`` and
    ``OPENAI_API_KEY`` from the environment, verifies the API key with a
    ``models.list()`` call, and then creates the module-level
    ``index_manager`` and ``evaluator``.

    Returns:
        A Markdown status string. Failures are reported through the returned
        string rather than raised.
    """
    global index_manager, evaluator
    try:
        persist_dir = os.getenv("VECTOR_DB_PATH", "./data/chroma")
        embedding_model = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
        llm_model = os.getenv("LLM_MODEL", "gpt-3.5-turbo")

        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            return "β **ERROR**: OPENAI_API_KEY not found! Please set it in your .env file or Space Secrets."

        # Make a minimal API call so a bad key is caught now, not at query time.
        try:
            from openai import OpenAI
            OpenAI(api_key=api_key).models.list()
            logger.info("β OpenAI API key validated successfully")
        except Exception as e:
            logger.error("OpenAI API key validation failed: %s", e)
            return f"β **ERROR**: Invalid OpenAI API key. {str(e)}"

        # Build both components before publishing either one, so a failure in
        # the second constructor cannot leave the globals half-initialized.
        new_index_manager = IndexManager(
            persist_directory=persist_dir,
            embedding_model_name=embedding_model
        )
        new_evaluator = RAGEvaluator(embedding_model_name=embedding_model)
        index_manager = new_index_manager
        evaluator = new_evaluator

        logger.info("System initialized successfully")
        # Report the models actually configured instead of hard-coded defaults.
        return "\n".join([
            "β **System initialized successfully!**",
            "**Components loaded:**",
            "- β Vector Database: ChromaDB",
            f"- β Embedding Model: {embedding_model}",
            f"- β LLM: OpenAI {llm_model}",
            "- β Evaluation Metrics: Ready",
            "**Next steps:**",
            '1. Go to "Upload Documents" tab',
            "2. Upload your PDF/TXT files",
            "3. Select appropriate hierarchy",
            "4. Build the RAG index",
        ])
    except Exception as e:
        # logger.exception keeps the traceback in app.log for diagnosis.
        logger.exception("Initialization failed: %s", e)
        return f"β **Initialization failed**: {str(e)}\n\nPlease check your configuration and try again."
def upload_documents(
    files: List[Any],
    hierarchy_choice: str,
    mask_pii: bool = False,
    progress=gr.Progress()
) -> Tuple[str, str, Dict[str, Any]]:
    """
    Validate an upload batch by file extension and summarize the outcome.

    Args:
        files: Uploaded items; each is either an object exposing ``.name``
            or a value whose ``str()`` is the file path
        hierarchy_choice: Selected hierarchy (hospital, bank, fluid_simulation)
        mask_pii: Accepted for interface compatibility; not used here
        progress: Gradio progress tracker; not used here

    Returns:
        Tuple of (status_message, preview_text, upload_stats). When ``files``
        is empty the stats dict is empty as well.
    """
    if not files:
        return "No files uploaded.", "", {}

    allowed_suffixes = {'.pdf', '.txt'}
    accepted: List[str] = []   # full paths of files that passed validation
    rejected: List[str] = []   # base names of files with other extensions

    for item in files:
        # Uploads may arrive as file-like objects or as plain path strings.
        location = item.name if hasattr(item, 'name') else str(item)
        if Path(location).suffix.lower() in allowed_suffixes:
            accepted.append(location)
        else:
            rejected.append(Path(location).name)

    n_total, n_ok, n_bad = len(files), len(accepted), len(rejected)

    report = [
        f"Uploaded {n_total} file(s)\n",
        f"Valid: {n_ok}, Invalid: {n_bad}\n",
        f"Selected Hierarchy: {hierarchy_choice}\n",
    ]
    if accepted:
        report.append("\nValid Files:")
        # Only the first five names are listed to keep the preview short.
        report.extend(f" - {Path(p).name}" for p in accepted[:5])
        if n_ok > 5:
            report.append(f" ... and {n_ok - 5} more")
    if rejected:
        report.append("\nInvalid Files (skipped):")
        report.extend(f" - {name}" for name in rejected)

    summary = {
        "total_uploaded": n_total,
        "valid_files": n_ok,
        "invalid_files": n_bad,
        "hierarchy": hierarchy_choice
    }
    status = (
        f"β {n_ok} files ready for processing."
        if accepted
        else "β No valid files to process."
    )
    return status, "\n".join(report), summary
# Index construction
def build_rag_index(
    files: List[Any],
    hierarchy_choice: str,
    chunk_size: int = 512,
    chunk_overlap: int = 50,
    mask_pii: bool = False,
    collection_name: str = "rag_documents",
    use_llm_classification: bool = True,
    progress=gr.Progress()
) -> Tuple[str, Dict[str, Any]]:
    """
    Build RAG index from uploaded documents.

    Processes every ``.pdf``/``.txt`` file into chunks, indexes the chunks
    through the module-level ``index_manager`` and, on success, replaces the
    module-level ``rag_comparator`` and records the active hierarchy and
    collection name.

    Args:
        files: List of uploaded file objects (or path strings)
        hierarchy_choice: Selected hierarchy
        chunk_size: Chunk size in tokens
        chunk_overlap: Overlap between chunks
        mask_pii: Whether to mask PII
        collection_name: Collection name
        use_llm_classification: Use LLM for better classification
        progress: Gradio progress tracker

    Returns:
        Tuple of (status_message, index_stats). ``index_stats`` is empty on
        any failure; errors are reported in the status message, not raised.
    """
    global index_manager, rag_comparator, current_hierarchy, current_collection
    if not files:
        return "β No files to process.", {}
    # Fail fast: otherwise the missing manager would only surface as an
    # AttributeError after every document had already been processed.
    if index_manager is None:
        return "System not initialized. Please run 'Initialize System' first.", {}
    try:
        # Convert file objects to paths, keeping only supported extensions.
        valid_files = []
        for file_obj in files:
            file_path = file_obj.name if hasattr(file_obj, 'name') else str(file_obj)
            if Path(file_path).suffix.lower() in {'.pdf', '.txt'}:
                valid_files.append(file_path)
        if not valid_files:
            return "β No valid files to process.", {}

        progress(0.05, desc="π§ Initializing document processor...")
        logger.info(f"Starting index build: {len(valid_files)} files, hierarchy={hierarchy_choice}")
        processor = DocumentProcessor(
            hierarchy_name=hierarchy_choice,
            # UI sliders may deliver floats; chunk sizes are integer token counts.
            chunk_size=int(chunk_size),
            chunk_overlap=int(chunk_overlap),
            mask_pii=mask_pii,
            use_llm_classification=use_llm_classification
        )

        progress(0.15, desc="π Processing documents...")
        all_chunks = []
        failed_files: List[str] = []
        for i, filepath in enumerate(valid_files):
            # Document processing occupies the 15%-65% band of the progress bar.
            file_progress = 0.15 + (0.50 * i / len(valid_files))
            progress(file_progress, desc=f"π Processing {Path(filepath).name}... ({i+1}/{len(valid_files)})")
            try:
                chunks = processor.process_document(filepath)
                all_chunks.extend(chunks)
                logger.info(f"Processed {filepath}: {len(chunks)} chunks")
            except Exception:
                # Best effort: one unreadable file must not abort the batch,
                # but it is remembered so the user is told about it.
                logger.exception("Error processing %s", filepath)
                failed_files.append(Path(filepath).name)
                continue
        if not all_chunks:
            return "β No chunks extracted from documents. Please check your files.", {}

        progress(0.65, desc=f"πΎ Extracted {len(all_chunks)} chunks, building vector index...")
        logger.info(f"Total chunks extracted: {len(all_chunks)}")

        progress(0.75, desc="π Generating embeddings...")
        stats = index_manager.index_documents(all_chunks, collection_name)

        progress(0.85, desc="π€ Initializing RAG pipelines...")
        vector_store = index_manager.get_store(collection_name)
        api_key = os.getenv("OPENAI_API_KEY")
        llm_model = os.getenv("LLM_MODEL", "gpt-3.5-turbo")
        rag_comparator = RAGComparator(
            vector_store=vector_store,
            llm_model=llm_model,
            api_key=api_key
        )
        # Publish the active hierarchy/collection only once indexing and
        # pipeline creation have both succeeded.
        current_hierarchy = hierarchy_choice
        current_collection = collection_name
        progress(1.0, desc="β Complete!")

        chunks_added = stats.get("chunks_added", 0)
        stats_display = {
            "β Status": "Successfully indexed",
            "π¦ Total Chunks": chunks_added,
            "ποΈ Collection": collection_name,
            "π·οΈ Hierarchy": hierarchy_choice,
            "π§ Embedding Model": stats.get("model_name", "Unknown"),
            "π Embedding Dimension": stats.get("embedding_dimension", 0),
            "π€ LLM Classification": "Enabled" if use_llm_classification else "Disabled"
        }
        if failed_files:
            stats_display["Skipped Files"] = failed_files

        classification = (
            "LLM-based (high accuracy)" if use_llm_classification else "Keyword-based (faster)"
        )
        status_lines = [
            f"β **Successfully indexed {chunks_added} chunks!**",
            "**Index Details:**",
            f"- Collection: `{collection_name}`",
            f"- Hierarchy: `{hierarchy_choice}`",
            f"- Classification: {classification}",
        ]
        if failed_files:
            status_lines.append(
                f"- Skipped files (processing errors): {', '.join(failed_files)}"
            )
        status_lines.extend([
            "**Next steps:**",
            '1. Go to "Search" tab to test queries',
            '2. Or go to "Chat" tab for conversational interface',
            '3. Or run "Evaluate" to get quantitative metrics',
        ])
        status = "\n".join(status_lines)

        logger.info(f"Index built successfully: {chunks_added} chunks")
        return status, stats_display
    except Exception as e:
        import traceback
        logger.exception("Error building index: %s", e)
        return f"β **Error building index**: {str(e)}\n\n```\n{traceback.format_exc()}\n```", {}
def _search_clean_filter(value: Optional[str]) -> Optional[str]:
    """Return the filter text without surrounding whitespace, or None if blank/missing."""
    if value is None:
        return None
    text = str(value).strip()
    return text or None


def _search_context_block(contexts: List[Dict[str, Any]], limit: int, width: int, template: str) -> str:
    """Render up to ``limit`` contexts, each truncated to ``width`` characters."""
    return "".join(
        template.format(i=i, text=ctx['document'][:width])
        for i, ctx in enumerate(contexts[:limit], 1)
    )


def _search_timing_block(result: Dict[str, Any], heading: str) -> str:
    """Render the retrieval/generation/total timings of one pipeline result."""
    return (
        f"**{heading}:**\n"
        f" Retrieval: {result['retrieval_time']:.3f}s\n"
        f" Generation: {result['generation_time']:.3f}s\n"
        f" Total: {result['total_time']:.3f}s\n"
    )


def _search_filters_block(result: Dict[str, Any]) -> str:
    """Render the non-empty applied filters, or '' when the result has none recorded."""
    if 'applied_filters' not in result:
        return ""
    rows = [f" {key}: {val}\n" for key, val in result['applied_filters'].items() if val]
    return "**Applied Filters:**\n" + "".join(rows)


def search_rag(
    query: str,
    pipeline: str,
    n_results: int = 5,
    level1: str = "",
    level2: str = "",
    level3: str = "",
    doc_type: str = "",
    auto_infer: bool = True
) -> Tuple[str, str, str]:
    """
    Search RAG system with a query.

    Args:
        query: Search query
        pipeline: Pipeline to use ("Base-RAG", "Both"; any other value runs Hier-RAG)
        n_results: Number of results
        level1: Level 1 filter (blank or None means no filter)
        level2: Level 2 filter (blank or None means no filter)
        level3: Level 3 filter (blank or None means no filter)
        doc_type: Document type filter (blank or None means no filter)
        auto_infer: Auto-infer filters

    Returns:
        Tuple of (answer, contexts, metadata). On failure the answer carries
        the error text and the other two elements are empty strings.
    """
    global rag_comparator
    if not rag_comparator:
        return "Please build the RAG index first.", "", ""
    # Callers outside the UI may pass None instead of an empty string.
    if not query or not query.strip():
        return "Please enter a query.", "", ""
    try:
        # Blank/None filters become None; real values are stripped so stray
        # whitespace typed into the UI cannot break matching downstream.
        filters = {
            "level1": _search_clean_filter(level1),
            "level2": _search_clean_filter(level2),
            "level3": _search_clean_filter(level3),
            "doc_type": _search_clean_filter(doc_type),
        }
        # Sliders can deliver floats; a result count must be an integer.
        n_results = int(n_results)

        if pipeline == "Both":
            result = rag_comparator.compare(
                query=query,
                n_results=n_results,
                auto_infer=auto_infer,
                **filters
            )
            base, hier = result['base_rag'], result['hier_rag']
            answer = (
                f"**Base-RAG Answer:**\n{base['answer']}\n\n"
                f"**Hier-RAG Answer:**\n{hier['answer']}\n\n"
                f"**Speedup:** {result['speedup']:.2f}x"
            )
            # Side-by-side view shows three short previews per pipeline.
            numbered = "\n{i}. {text}...\n"
            contexts = (
                "**Base-RAG Contexts:**\n"
                + _search_context_block(base['contexts'], 3, 200, numbered)
                + "\n**Hier-RAG Contexts:**\n"
                + _search_context_block(hier['contexts'], 3, 200, numbered)
            )
            metadata = (
                _search_timing_block(base, "Base-RAG Timing") + "\n"
                + _search_timing_block(hier, "Hier-RAG Timing") + "\n"
                + _search_filters_block(hier)
            )
        elif pipeline == "Base-RAG":
            result = rag_comparator.base_rag.query(query, n_results)
            answer = result['answer']
            contexts = _search_context_block(
                result['contexts'], 5, 300, "\n**Context {i}:**\n{text}...\n"
            )
            metadata = _search_timing_block(result, "Timing")
        else:  # Hier-RAG
            result = rag_comparator.hier_rag.query(
                query=query,
                n_results=n_results,
                auto_infer=auto_infer,
                **filters
            )
            answer = result['answer']
            contexts = _search_context_block(
                result['contexts'], 5, 300, "\n**Context {i}:**\n{text}...\n"
            )
            metadata = (
                _search_timing_block(result, "Timing") + "\n"
                + _search_filters_block(result)
            )
        return answer, contexts, metadata
    except Exception as e:
        # Keep the traceback in the log; the UI only receives the short message.
        logger.exception("Search failed: %s", e)
        return f"Error: {str(e)}", "", ""
def chat_interface(
    message: str,
    history: List[Tuple[str, str]],
    pipeline: str,
    n_results: int
) -> Tuple[List[Tuple[str, str]], str]:
    """
    Chat interface for conversational queries.

    Args:
        message: User message
        history: Chat history as (user, assistant) pairs; None is treated as empty
        pipeline: "Base-RAG" runs the base pipeline; any other value runs Hier-RAG
        n_results: Number of results

    Returns:
        Tuple of (updated_history, sources). A blank message leaves the
        history unchanged. Errors are appended to the history as the
        assistant reply instead of being raised.
    """
    global rag_comparator
    # Work on a copy and tolerate a missing history (None) from the caller.
    history = list(history) if history else []
    # Same rule as the search tab: never send a blank query to the pipeline.
    if not message or not message.strip():
        return history, ""
    if not rag_comparator:
        history.append((message, "Please build the RAG index first."))
        return history, ""
    try:
        # Sliders can deliver floats; a result count must be an integer.
        n_results = int(n_results)
        if pipeline == "Base-RAG":
            result = rag_comparator.base_rag.query(message, n_results)
        else:  # Hier-RAG
            result = rag_comparator.hier_rag.query(message, n_results, auto_infer=True)
        answer = result['answer']

        # Show at most three sources, each with a short preview.
        source_parts = ["**Sources:**\n"]
        for i, ctx in enumerate(result['contexts'][:3], 1):
            meta = ctx.get('metadata') or {}
            source_parts.append(f"\n{i}. Source: {meta.get('source_name', 'Unknown')}\n")
            source_parts.append(
                f" Level1: {meta.get('level1', 'N/A')}, Level2: {meta.get('level2', 'N/A')}\n"
            )
            source_parts.append(f" Preview: {ctx['document'][:150]}...\n")

        history.append((message, answer))
        return history, "".join(source_parts)
    except Exception as e:
        # Keep the traceback in the log; the chat only shows the short message.
        logger.exception("Chat query failed: %s", e)
        history.append((message, f"Error: {str(e)}"))
        return history, ""
# Quantitative evaluation
def _evaluation_timing_stats(df: pd.DataFrame) -> Dict[str, Any]:
    """Compute speed-up and latency aggregates directly from the per-query results."""
    speedup = df['speedup']
    wins = int((speedup > 1.0).sum())
    base_retrieval = df['base_retrieval_time'].mean()
    hier_retrieval = df['hier_retrieval_time'].mean()
    # Avoid dividing by a zero baseline when retrieval times were not measured.
    improvement = (
        (base_retrieval - hier_retrieval) / base_retrieval * 100
        if base_retrieval else 0.0
    )
    return {
        'total_queries': len(df),
        'avg_speedup': speedup.mean(),
        'median_speedup': speedup.median(),
        'max_speedup': speedup.max(),
        'min_speedup': speedup.min(),
        'hier_wins': wins,
        'win_rate': wins / len(df) * 100,
        'base_avg_total': df['base_total_time'].mean(),
        'hier_avg_total': df['hier_total_time'].mean(),
        'base_avg_retrieval': base_retrieval,
        'hier_avg_retrieval': hier_retrieval,
        'retrieval_improvement': improvement,
    }


def run_evaluation(
    query_dataset: str,
    n_queries: int = 10,
    k_values: str = "1,3,5",
    progress=gr.Progress()
) -> Tuple[str, str, Optional[str]]:
    """
    Run quantitative evaluation.

    Runs the Base-RAG / Hier-RAG comparison over the first ``n_queries``
    benchmark queries, writes the per-query timings to ``./reports`` as CSV
    and JSON, and builds a Markdown summary.

    Args:
        query_dataset: Dataset selection ("hospital", "bank"; any other value
            selects the fluid-simulation queries)
        n_queries: Number of queries to evaluate
        k_values: Comma-separated k values
        progress: Progress tracker

    Returns:
        Tuple of (summary, csv_path, visualization_path). ``visualization_path``
        is None when no chart file was produced. On failure the summary holds
        the error text, ``csv_path`` is "" and ``visualization_path`` is None.
    """
    global rag_comparator, evaluator
    if not rag_comparator or not evaluator:
        return "Please build the RAG index first.", "", None
    try:
        # NOTE: the k values are validated and echoed in the report, but this
        # function only records timing and speed-up per query.
        try:
            k_list = [int(k) for k in k_values.split(',') if k.strip()]
        except ValueError:
            return (
                f"Invalid K values '{k_values}': expected comma-separated integers (e.g. 1,3,5).",
                "",
                None,
            )

        benchmark = BenchmarkDataset()
        if query_dataset == "hospital":
            queries = benchmark.get_sample_hospital_queries()
        elif query_dataset == "bank":
            queries = benchmark.get_sample_bank_queries()
        else:
            queries = benchmark.get_sample_fluid_simulation_queries()
        # Sliders can deliver floats, which are not valid slice bounds.
        queries = queries[:int(n_queries)]
        if not queries:
            return "No benchmark queries available for the selected dataset.", "", None

        results = []
        for i, query_data in enumerate(queries):
            progress((i / len(queries)), desc=f"Evaluating query {i+1}/{len(queries)}...")
            query = query_data['query']
            comparison = rag_comparator.compare(query=query, n_results=5, auto_infer=True)
            base, hier = comparison['base_rag'], comparison['hier_rag']
            result = {
                "query": query,
                "expected_domain": query_data.get('domain', 'N/A'),
                "base_retrieval_time": base['retrieval_time'],
                "base_total_time": base['total_time'],
                "hier_retrieval_time": hier['retrieval_time'],
                "hier_total_time": hier['total_time'],
                "speedup": comparison['speedup']
            }
            # One column per applied filter; empty filters are stored as "None".
            if 'applied_filters' in hier:
                for key, val in hier['applied_filters'].items():
                    result[f"filter_{key}"] = val or "None"
            results.append(result)

        df = pd.DataFrame(results)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        reports_dir = Path("./reports")
        reports_dir.mkdir(exist_ok=True)
        csv_path = reports_dir / f"evaluation_{timestamp}.csv"
        json_path = reports_dir / f"evaluation_{timestamp}.json"
        df.to_csv(csv_path, index=False)
        save_json(results, str(json_path))

        progress(0.9, desc="Generating visualizations...")
        report_stats = None
        visualization_path = None
        try:
            from core.eval_utils import generate_evaluation_report
            report_stats = generate_evaluation_report(str(csv_path))
            # Derive the chart name from the file name only, never from the
            # directory part of the path.
            chart_path = csv_path.with_name(f"{csv_path.stem}_report_charts.png")
            if chart_path.exists():
                visualization_path = str(chart_path)
            else:
                logger.warning(f"Visualization not generated: {chart_path}")
        except Exception as e:
            logger.exception("Error generating visualization: %s", e)
            report_stats = None

        # Locally computed aggregates guarantee every key the summary needs;
        # values from the report helper take precedence when it returned a dict.
        summary_stats = _evaluation_timing_stats(df)
        if isinstance(report_stats, dict):
            summary_stats.update(report_stats)

        summary_lines = [
            f"# Evaluation Report ({timestamp})",
            f"\n## Configuration",
            f"- **Dataset**: {query_dataset}",
            f"- **Queries Evaluated**: {len(queries)}",
            f"- **K Values**: {','.join(str(k) for k in k_list)}",
            f"\n## Performance Summary",
            f"- **Average Speedup**: {summary_stats['avg_speedup']:.2f}x",
            f"- **Median Speedup**: {summary_stats['median_speedup']:.2f}x",
            f"- **Hier-RAG Win Rate**: {summary_stats['win_rate']:.1f}% ({summary_stats['hier_wins']}/{summary_stats['total_queries']} queries)",
            f"\n## Timing Results",
            f"### Base-RAG",
            f"- Avg Retrieval Time: {summary_stats['base_avg_retrieval']:.3f}s",
            f"- Avg Total Time: {summary_stats['base_avg_total']:.3f}s",
            f"\n### Hier-RAG",
            f"- Avg Retrieval Time: {summary_stats['hier_avg_retrieval']:.3f}s",
            f"- Avg Total Time: {summary_stats['hier_avg_total']:.3f}s",
            f"- **Retrieval Improvement**: {summary_stats['retrieval_improvement']:.1f}%",
            f"\n## Speed Analysis",
            f"- **Maximum Speedup**: {summary_stats['max_speedup']:.2f}x",
            f"- **Minimum Speedup**: {summary_stats['min_speedup']:.2f}x",
        ]
        if summary_stats['avg_speedup'] > 1.2:
            summary_lines.append(f"\nβ **Conclusion**: Hier-RAG shows **significant performance improvement** (>20% faster)")
        elif summary_stats['avg_speedup'] > 1.0:
            summary_lines.append(f"\nβ **Conclusion**: Hier-RAG shows **moderate performance improvement**")
        else:
            summary_lines.append(f"\nβ οΈ **Conclusion**: Hier-RAG needs optimization - filter inference may need improvement")
        summary_lines.extend([
            f"\n## Output Files",
            f"- **CSV**: `{csv_path.name}`",
            f"- **JSON**: `{json_path.name}`",
        ])
        if visualization_path and Path(visualization_path).exists():
            summary_lines.append(f"- **Visualization**: `{Path(visualization_path).name}`")
            summary_lines.append(f"- **Detailed Report**: `{csv_path.stem}_report_summary.md`")
        else:
            summary_lines.append(f"- **Visualization**: Not generated (install matplotlib/seaborn)")
        summary = "\n".join(summary_lines)

        progress(1.0, desc="Complete!")
        return summary, str(csv_path), visualization_path
    except Exception as e:
        import traceback
        error_msg = f"Error during evaluation: {str(e)}\n\n{traceback.format_exc()}"
        logger.error(error_msg)
        return error_msg, "", None
# System health check
def system_health_check():
    """Probe each component and return a mapping of check label to status text.

    Every probe is isolated: a failure in one is recorded under its own label
    (exception text truncated to 50 characters) and the remaining probes
    still run.
    """
    report = {}

    # 1. OpenAI API: key must be present and accepted by a models.list() call.
    openai_label = "π OpenAI API"
    try:
        token = os.getenv("OPENAI_API_KEY")
        if token:
            from openai import OpenAI
            OpenAI(api_key=token).models.list()
            report[openai_label] = "β Connected and authenticated"
        else:
            report[openai_label] = "β API key not found"
    except Exception as err:
        report[openai_label] = f"β {str(err)[:50]}"

    # 2. Vector database: available only once the index manager exists.
    vector_label = " Vector Database"
    try:
        if not index_manager:
            report[vector_label] = "β οΈ Not initialized yet"
        else:
            n_collections = len(index_manager.list_collections())
            report[vector_label] = f"β Initialized ({n_collections} collections)"
    except Exception as err:
        report[vector_label] = f"β {str(err)[:50]}"

    # 3. Embedding model: embed a probe string and report the vector length.
    embedding_label = "π§ Embedding Model"
    try:
        from core.index import EmbeddingModel
        probe_vector = EmbeddingModel().embed_query("test")
        report[embedding_label] = f"β Loaded ({len(probe_vector)} dimensions)"
    except Exception as err:
        report[embedding_label] = f"β {str(err)[:50]}"

    # 4. RAG pipelines: present once an index has been built.
    pipelines_label = " RAG Pipelines"
    try:
        report[pipelines_label] = (
            "β Base-RAG and Hier-RAG ready"
            if rag_comparator
            else "β οΈ Not initialized (build index first)"
        )
    except Exception as err:
        report[pipelines_label] = f"β {str(err)[:50]}"

    # 5. Disk space on the volume holding the vector store.
    disk_label = " Disk Space"
    try:
        import shutil
        store_dir = os.getenv("VECTOR_DB_PATH", "./data/chroma")
        if not Path(store_dir).exists():
            report[disk_label] = "β οΈ Vector DB path not created yet"
        else:
            free_gb = shutil.disk_usage(store_dir).free // (2**30)
            report[disk_label] = f"β {free_gb} GB free"
    except Exception as err:
        report[disk_label] = f"β {str(err)[:50]}"

    # 6. Environment variables the app reads.
    expected_vars = ("OPENAI_API_KEY", "VECTOR_DB_PATH", "EMBEDDING_MODEL", "LLM_MODEL")
    unset = [name for name in expected_vars if not os.getenv(name)]
    report[" Environment"] = (
        f"β οΈ Missing: {', '.join(unset)}" if unset else "β All variables set"
    )
    return report
# Gradio interface
| def create_interface(): | |
| """Create the Gradio interface.""" | |
| with gr.Blocks( | |
| title="Hierarchical RAG Evaluation", | |
| theme=gr.themes.Soft(), | |
| css=""" | |
| .gradio-container { | |
| font-family: 'Inter', sans-serif; | |
| } | |
| .gr-button-primary { | |
| background: linear-gradient(90deg, #667eea 0%, #764ba2 100%) !important; | |
| border: none !important; | |
| } | |
| .gr-button-primary:hover { | |
| transform: scale(1.02); | |
| box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4); | |
| } | |
| """ | |
| ) as demo: | |
| gr.Markdown(""" | |
| # Hierarchical RAG Evaluation System | |
| Compare **Base-RAG** vs **Hier-RAG** performance on accuracy and speed. | |
| **Hier-RAG** uses hierarchical metadata filtering to reduce search space and improve retrieval speed. | |
| """) | |
| # Initialize button at the top | |
| with gr.Row(): | |
| with gr.Column(scale=3): | |
| init_btn = gr.Button(" Initialize System", variant="primary", size="lg") | |
| with gr.Column(scale=1): | |
| health_btn = gr.Button(" Health Check", size="lg") | |
| with gr.Row(): | |
| init_status = gr.Markdown(label="Status") | |
| health_status = gr.JSON(label="System Health", visible=False) | |
| init_btn.click( | |
| initialize_system, | |
| outputs=[init_status], | |
| api_name="initialize" | |
| ) | |
| health_btn.click( | |
| system_health_check, | |
| outputs=[health_status], | |
| api_name="health_check" | |
| ).then( | |
| lambda: gr.update(visible=True), | |
| outputs=[health_status] | |
| ) | |
| with gr.Tabs(): | |
| # Tab 1: Upload Documents | |
| with gr.Tab("1οΈβ£ Upload Documents"): | |
| gr.Markdown(""" | |
| ### Upload Documents | |
| Upload PDF or TXT files to build your RAG system. | |
| **Supported formats:** `.pdf`, `.txt` | |
| """) | |
| with gr.Row(): | |
| with gr.Column(): | |
| file_upload = gr.File( | |
| label=" Select Documents", | |
| file_count="multiple", | |
| file_types=[".pdf", ".txt"] | |
| ) | |
| hierarchy_choice = gr.Dropdown( | |
| choices=["hospital", "bank", "fluid_simulation"], | |
| value="hospital", | |
| label=" Select Hierarchy", | |
| info="Choose the domain that best matches your documents" | |
| ) | |
| mask_pii_check = gr.Checkbox( | |
| label=" Mask PII (Personally Identifiable Information)", | |
| value=False, | |
| info="Automatically mask emails, phone numbers, SSN" | |
| ) | |
| upload_btn = gr.Button("β Validate Upload", variant="primary") | |
| with gr.Column(): | |
| upload_status = gr.Textbox( | |
| label=" Upload Status", | |
| interactive=False, | |
| placeholder="Upload files to see validation results..." | |
| ) | |
| upload_preview = gr.Textbox( | |
| label=" Preview", | |
| lines=10, | |
| interactive=False, | |
| placeholder="File details will appear here..." | |
| ) | |
| upload_stats = gr.JSON(label=" Statistics") | |
| upload_btn.click( | |
| upload_documents, | |
| inputs=[file_upload, hierarchy_choice, mask_pii_check], | |
| outputs=[upload_status, upload_preview, upload_stats], | |
| api_name="upload" | |
| ) | |
| # Tab 2: Build RAG Index | |
| with gr.Tab("2οΈβ£ Build RAG"): | |
| gr.Markdown(""" | |
| ### Build Vector Index | |
| Process documents and create searchable vector database. | |
| **This may take a few minutes for large documents.** | |
| """) | |
| with gr.Row(): | |
| with gr.Column(): | |
| build_files = gr.File( | |
| label=" Select Files to Index", | |
| file_count="multiple", | |
| file_types=[".pdf", ".txt"] | |
| ) | |
| build_hierarchy = gr.Dropdown( | |
| choices=["hospital", "bank", "fluid_simulation"], | |
| value="hospital", | |
| label=" Hierarchy" | |
| ) | |
| with gr.Accordion(" Advanced Options", open=False): | |
| chunk_size = gr.Slider( | |
| minimum=128, | |
| maximum=1024, | |
| value=512, | |
| step=64, | |
| label="π Chunk Size (tokens)", | |
| info="Larger chunks = more context, slower retrieval" | |
| ) | |
| chunk_overlap = gr.Slider( | |
| minimum=0, | |
| maximum=200, | |
| value=50, | |
| step=10, | |
| label="π Chunk Overlap (tokens)", | |
| info="Overlap helps maintain context across chunks" | |
| ) | |
| build_mask_pii = gr.Checkbox( | |
| label=" Mask PII", | |
| value=False | |
| ) | |
| use_llm_classification = gr.Checkbox( | |
| label=" Use LLM for Classification (Recommended)", | |
| value=True, | |
| info="More accurate but slower. Disable for faster processing." | |
| ) | |
| collection_name = gr.Textbox( | |
| label=" Collection Name", | |
| value="rag_documents", | |
| info="Name for this document collection" | |
| ) | |
| build_btn = gr.Button(" Build Index", variant="primary", size="lg") | |
| with gr.Column(): | |
| build_status = gr.Markdown( | |
| label="Status", | |
| value="Click 'Build Index' to start processing..." | |
| ) | |
| build_stats = gr.JSON(label=" Index Statistics") | |
| build_btn.click( | |
| build_rag_index, | |
| inputs=[ | |
| build_files, | |
| build_hierarchy, | |
| chunk_size, | |
| chunk_overlap, | |
| build_mask_pii, | |
| collection_name, | |
| use_llm_classification | |
| ], | |
| outputs=[build_status, build_stats], | |
| api_name="build" | |
| ) | |
| # Tab 3: Search | |
| with gr.Tab("3οΈβ£ Search"): | |
| gr.Markdown(""" | |
| ### Query the RAG System | |
| Test your queries and compare Base-RAG vs Hier-RAG performance. | |
| """) | |
| with gr.Row(): | |
| with gr.Column(): | |
| search_query = gr.Textbox( | |
| label=" Query", | |
| placeholder="e.g., What are the patient admission procedures?", | |
| lines=3 | |
| ) | |
| search_pipeline = gr.Radio( | |
| choices=["Base-RAG", "Hier-RAG", "Both"], | |
| value="Both", | |
| label=" Pipeline Selection", | |
| info="'Both' compares performance side-by-side" | |
| ) | |
| search_n_results = gr.Slider( | |
| minimum=1, | |
| maximum=20, | |
| value=5, | |
| step=1, | |
| label=" Number of Results" | |
| ) | |
| with gr.Accordion(" Hierarchical Filters (Hier-RAG only)", open=False): | |
| gr.Markdown("*Leave empty for auto-inference*") | |
| filter_level1 = gr.Textbox( | |
| label="Level 1 (Domain)", | |
| placeholder="e.g., Clinical Care" | |
| ) | |
| filter_level2 = gr.Textbox( | |
| label="Level 2 (Section)", | |
| placeholder="e.g., Patient Records" | |
| ) | |
| filter_level3 = gr.Textbox( | |
| label="Level 3 (Topic)", | |
| placeholder="e.g., Admission Notes" | |
| ) | |
| filter_doc_type = gr.Textbox( | |
| label="Document Type", | |
| placeholder="e.g., policy, manual, protocol" | |
| ) | |
| filter_auto_infer = gr.Checkbox( | |
| label=" Auto-infer filters from query", | |
| value=True, | |
| info="Uses LLM to automatically detect appropriate filters" | |
| ) | |
| search_btn = gr.Button("π Search", variant="primary", size="lg") | |
| with gr.Column(): | |
| search_answer = gr.Markdown(label="π‘ Answer") | |
| with gr.Accordion(" Retrieved Contexts", open=False): | |
| search_contexts = gr.Textbox( | |
| label="Contexts", | |
| lines=8, | |
| interactive=False | |
| ) | |
| with gr.Accordion("β± Performance Metrics", open=True): | |
| search_metadata = gr.Textbox( | |
| label="Metadata & Timing", | |
| lines=8, | |
| interactive=False | |
| ) | |
| search_btn.click( | |
| search_rag, | |
| inputs=[ | |
| search_query, | |
| search_pipeline, | |
| search_n_results, | |
| filter_level1, | |
| filter_level2, | |
| filter_level3, | |
| filter_doc_type, | |
| filter_auto_infer | |
| ], | |
| outputs=[search_answer, search_contexts, search_metadata], | |
| api_name="search" | |
| ) | |
| # Tab 4: Chat | |
| with gr.Tab("4οΈβ£ Chat"): | |
| gr.Markdown(""" | |
| ### Conversational Interface | |
| Have a conversation with your documents. Sources are shown for each answer. | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| chatbot = gr.Chatbot( | |
| label="Chat History", | |
| height=500, | |
| avatar_images=(None, "π€") | |
| ) | |
| with gr.Row(): | |
| chat_input = gr.Textbox( | |
| label="Message", | |
| placeholder="Ask a question about your documents...", | |
| lines=2, | |
| scale=4 | |
| ) | |
| with gr.Row(): | |
| chat_submit = gr.Button(" Send", variant="primary", scale=3) | |
| chat_clear = gr.Button(" Clear", scale=1) | |
| with gr.Column(scale=1): | |
| chat_pipeline = gr.Radio( | |
| choices=["Base-RAG", "Hier-RAG"], | |
| value="Hier-RAG", | |
| label=" Pipeline" | |
| ) | |
| chat_n_results = gr.Slider( | |
| minimum=1, | |
| maximum=10, | |
| value=3, | |
| step=1, | |
| label=" Context Documents" | |
| ) | |
| chat_sources = gr.Textbox( | |
| label=" Sources", | |
| lines=15, | |
| interactive=False, | |
| placeholder="Sources will appear here after you ask a question..." | |
| ) | |
| chat_submit.click( | |
| chat_interface, | |
| inputs=[chat_input, chatbot, chat_pipeline, chat_n_results], | |
| outputs=[chatbot, chat_sources], | |
| api_name="chat" | |
| ).then( | |
| lambda: "", | |
| outputs=[chat_input] | |
| ) | |
| chat_input.submit( | |
| chat_interface, | |
| inputs=[chat_input, chatbot, chat_pipeline, chat_n_results], | |
| outputs=[chatbot, chat_sources], | |
| api_name="chat_submit" | |
| ).then( | |
| lambda: "", | |
| outputs=[chat_input] | |
| ) | |
| chat_clear.click( | |
| lambda: ([], ""), | |
| outputs=[chatbot, chat_sources], | |
| api_name="clear_chat" | |
| ) | |
| # Tab 5: Evaluate | |
| with gr.Tab("5οΈβ£ Evaluate"): | |
| gr.Markdown(""" | |
| ### Quantitative Evaluation | |
| Run systematic evaluation to compare Base-RAG vs Hier-RAG performance. | |
| **Metrics computed:** Hit@k, MRR, Precision, Recall, Latency, Speedup | |
| """) | |
| with gr.Row(): | |
| with gr.Column(): | |
| eval_dataset = gr.Dropdown( | |
| choices=["hospital", "bank", "fluid_simulation"], | |
| value="hospital", | |
| label=" Query Dataset" | |
| ) | |
| eval_n_queries = gr.Slider( | |
| minimum=1, | |
| maximum=50, | |
| value=10, | |
| step=1, | |
| label=" Number of Queries" | |
| ) | |
| eval_k_values = gr.Textbox( | |
| label="K Values (comma-separated)", | |
| value="1,3,5", | |
| placeholder="1,3,5", | |
| info="For Hit@k, Precision@k, Recall@k metrics" | |
| ) | |
| eval_btn = gr.Button(" Run Evaluation", variant="primary", size="lg") | |
| with gr.Column(): | |
| eval_summary = gr.Markdown( | |
| label="Summary", | |
| value="Click 'Run Evaluation' to start..." | |
| ) | |
| eval_csv = gr.Textbox( | |
| label=" CSV Output Path", | |
| interactive=False | |
| ) | |
| eval_visualization = gr.Image( | |
| label=" Performance Visualization", | |
| type="filepath" | |
| ) | |
| eval_btn.click( | |
| run_evaluation, | |
| inputs=[eval_dataset, eval_n_queries, eval_k_values], | |
| outputs=[eval_summary, eval_csv, eval_visualization], | |
| api_name="evaluate" | |
| ) | |
| # Footer | |
| gr.Markdown(""" | |
| --- | |
| ### Quick Reference | |
| | Pipeline | Description | Best For | | |
| |----------|-------------|----------| | |
| | **Base-RAG** | Standard vector similarity search | General queries, exploratory search | | |
| | **Hier-RAG** | Hierarchical filtering + vector search | Domain-specific queries, large document sets | | |
| **Tips:** | |
| - Use **Hier-RAG** when you know the domain/section of your query | |
| - Use **Both** to compare performance | |
| - Enable **LLM Classification** for best accuracy | |
| - Run **Evaluate** to get quantitative metrics | |
| ---""") | |
| # **Need help?** Check the [Documentation](README.md) or report issues on [GitHub](https://github.com/your-repo) | |
| # Built with ❤️ using [Gradio](https://gradio.app) | Powered by [OpenAI](https://openai.com) & [ChromaDB](https://trychroma.com) | |
| # | |
| return demo | |
# Launch the app
if __name__ == "__main__":
    # Initialize on startup. A failure here is deliberately non-fatal: the UI
    # is still launched so the system can be initialized manually from it.
    try:
        startup_status = initialize_system()
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error would drop.
        logger.exception("Startup initialization failed: %s", e)
        print("⚠️ Warning: System initialization failed. You can initialize manually from the UI.")
    else:
        # initialize_system() reports problems such as a missing
        # OPENAI_API_KEY through its returned status message instead of
        # raising, so record that message rather than discarding it.
        if startup_status:
            logger.info("Startup initialization status: %s", startup_status)

    # Create and launch interface
    demo = create_interface()
    demo.queue()  # Enable queueing for better handling of concurrent requests
    demo.launch(
        server_name="0.0.0.0",  # accept connections on all network interfaces
        server_port=7860,
        share=False,
        show_error=True,
        max_threads=10,
    )