"""Gradio front-end for the Invoice RAG system with a Groq-backed LLM."""

import asyncio
import json
import logging
import os
import shutil
import tempfile
from collections import Counter
from datetime import datetime
from typing import Dict, Any, Optional, List, Union

import gradio as gr

# LLM integrations
import groq

# Import the RAG system
from invoice_rag_system import InvoiceRAGSystem

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("invoice-rag-gradio")


def setup_environment():
    """Prepare the working directory for the app.

    Ensures the ``sample_invoices`` directory exists and prints a notice when
    the ``SPACE_ID`` environment variable is set.

    Returns:
        Always ``True``.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs("sample_invoices", exist_ok=True)

    # Check for HF Spaces environment
    space_id = os.getenv("SPACE_ID")
    if space_id:
        print(f"🚀 Running on Hugging Face Spaces: {space_id}")

    return True


class LLMManager:
    """Manage different LLM providers.

    ``self.providers`` maps a provider name to a dict holding its ``client``
    (``None`` until initialized), the list of ``models`` offered in the UI and
    the name of the environment variable carrying the API key.
    """

    def __init__(self):
        self.providers = {
            "groq": {
                "client": None,
                "models": [
                    "llama-3.3-70b-versatile",
                    "mixtral-8x7b-32768",
                    "llama-3.1-8b-instant",
                ],
                "api_key_env": "GROQ_API_KEY",
            },
        }
        self.initialize_clients()

    def initialize_clients(self):
        """Initialize LLM clients based on available API keys."""
        # Groq
        groq_config = self.providers["groq"]
        api_key = os.getenv(groq_config["api_key_env"])
        if api_key:
            try:
                groq_config["client"] = groq.Client(api_key=api_key)
                logger.info("Groq client initialized")
            except Exception as e:
                # A broken provider must not prevent the UI from starting.
                logger.error("Failed to initialize Groq client: %s", e)

    def get_available_providers(self) -> List[str]:
        """Return the names of providers whose client was initialized."""
        return [
            provider
            for provider, config in self.providers.items()
            if config["client"] is not None
        ]

    def get_models_for_provider(self, provider: str) -> List[str]:
        """Return the model names for ``provider``, or ``[]`` if unavailable."""
        config = self.providers.get(provider)
        if config and config["client"]:
            return config["models"]
        return []

    def generate_response(self, provider: str, model: str, prompt: str,
                          max_tokens: int = 4096, temperature: float = 0.7) -> str:
        """Generate a completion for ``prompt``.

        Args:
            provider: Provider name; only ``"groq"`` is handled.
            model: Model identifier passed to the provider.
            prompt: Text sent as a single user message.
            max_tokens: Completion token limit passed to the provider.
            temperature: Sampling temperature passed to the provider.

        Returns:
            The stripped completion text, or a string starting with
            ``"Error:"`` when the provider is unavailable or the call fails.
        """
        client = self.providers.get(provider, {}).get("client")
        # Without this guard an uninitialized client surfaced as an opaque
        # "'NoneType' object has no attribute 'chat'" error.
        if provider != "groq" or client is None:
            return f"Error: Provider {provider} not supported or not initialized"

        try:
            response = client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            content = response.choices[0].message.content
            return (content or "").strip()
        except Exception as e:
            logger.error("Error generating response with %s/%s: %s", provider, model, e)
            return f"Error: {str(e)}"


class InvoiceRAGInterface:
    """Gradio interface for Invoice RAG system with built-in API."""

    # Shared error payload text for every API function that needs a trained system.
    _NOT_TRAINED_MESSAGE = (
        "RAG system not trained. Please train the system first with invoice PDFs."
    )

    def __init__(self):
        setup_environment()
        self.rag_system = InvoiceRAGSystem()
        self.llm_manager = LLMManager()
        self.is_trained = False
        self.training_history = []
        # NOTE(review): nothing visible in this file reads temp_upload_dir or
        # removes it; kept only so external readers of the attribute keep working.
        self.temp_upload_dir = tempfile.mkdtemp()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _not_trained_json(self) -> str:
        """Return the JSON error payload used when the system is untrained."""
        return json.dumps({"error": self._NOT_TRAINED_MESSAGE})

    @staticmethod
    def _truncate(text: str, limit: int) -> str:
        """Return ``text`` cut to ``limit`` characters plus ``...`` if it was longer."""
        return text[:limit] + "..." if len(text) > limit else text

    @staticmethod
    def _stage_pdfs(pdf_paths: List[str], training_dir: str) -> None:
        """Copy ``pdf_paths`` into ``training_dir`` without clobbering same-named files."""
        for index, pdf_path in enumerate(pdf_paths):
            filename = os.path.basename(pdf_path)
            target = os.path.join(training_dir, filename)
            if os.path.exists(target):
                # Two uploads can share a basename; keep both instead of
                # silently overwriting the first one.
                stem, ext = os.path.splitext(filename)
                target = os.path.join(training_dir, f"{stem}_{index}{ext}")
            shutil.copy2(pdf_path, target)

    def _record_training(self, method: str, summary: dict, **extra) -> None:
        """Append one entry describing a finished training run to the history."""
        entry = {
            'timestamp': datetime.now().isoformat(),
            'method': method,
            **extra,
            'num_invoices': summary['total_invoices'],
            'num_chunks': summary['total_chunks'],
        }
        self.training_history.append(entry)

    def _ui_result(self, status: str, summary: dict) -> tuple:
        """Build the (status, summary JSON, history) tuple the training UI expects."""
        summary_text = json.dumps(summary, indent=2, ensure_ascii=False)
        return status, summary_text, self.format_training_history()

    # ------------------------------------------------------------------
    # API Functions (exposed via Gradio's built-in API)
    # ------------------------------------------------------------------
    def api_query_invoice_info(self, query: str, context_sections: Optional[str] = None) -> str:
        """Extract information from invoices using the RAG system.

        Args:
            query: The question to ask about the invoices
            context_sections: Comma-separated list of sections to focus on
                (header,vendor,client,items,totals,footer)

        Returns:
            JSON string with the extracted patterns and up to five retrieved
            chunks, or a JSON object with an ``error`` key on failure.
        """
        if not self.is_trained:
            return self._not_trained_json()

        if not query or not query.strip():
            return json.dumps({"error": "Please provide a query"})

        try:
            # Parse context sections
            sections = None
            if context_sections:
                sections = [s.strip() for s in context_sections.split(',') if s.strip()]

            # Extract information using RAG
            rag_results = self.rag_system.extract_invoice_info(query, sections)
            context_chunks = rag_results['context_chunks']

            # Format response
            response = {
                "success": True,
                "query": query,
                "sources_found": rag_results['num_sources'],
                "chunks_retrieved": len(context_chunks),
                "extracted_patterns": rag_results['extracted_patterns'],
                "relevant_chunks": [
                    {
                        "source": chunk['source'],
                        "type": chunk['type'],
                        "content": self._truncate(chunk['content'], 500),
                        # float() keeps json.dumps working if the score is a
                        # non-builtin numeric scalar.
                        "relevance_score": float(chunk['score']),
                    }
                    for chunk in context_chunks[:5]
                ],
            }

            return json.dumps(response, indent=2, ensure_ascii=False)

        except Exception as e:
            logger.error("API Query error: %s", e)
            return json.dumps({"error": f"Query failed: {str(e)}"})

    def api_get_invoice_summary(self) -> str:
        """Get a summary of all processed invoices and their patterns.

        Returns:
            JSON string with ``success`` and ``summary`` keys, or an ``error`` key.
        """
        if not self.is_trained:
            return self._not_trained_json()

        try:
            summary = self.rag_system.get_pattern_summary()
            return json.dumps({"success": True, "summary": summary},
                              indent=2, ensure_ascii=False)
        except Exception as e:
            logger.error("Summary error: %s", e)
            return json.dumps({"error": f"Failed to get summary: {str(e)}"})

    def api_extract_specific_field(self, field_name: str,
                                   invoice_source: Optional[str] = None) -> str:
        """Extract a specific field from invoices.

        Args:
            field_name: The field to extract (e.g., 'invoice_number', 'total', 'vendor_name')
            invoice_source: Optional specific invoice to search in

        Returns:
            JSON string listing the extracted patterns whose text contains
            ``field_name`` (case-insensitive), or an ``error`` key.
        """
        if not self.is_trained:
            return self._not_trained_json()

        # An empty needle is contained in every string, so it would "match"
        # every pattern; reject it explicitly.
        if not field_name or not field_name.strip():
            return json.dumps({"error": "Please provide a field name"})

        try:
            query = f"Find all {field_name} values"
            if invoice_source:
                query += f" from {invoice_source}"

            rag_results = self.rag_system.extract_invoice_info(query)

            # Extract the specific field from patterns
            needle = field_name.lower()
            field_values = [
                pattern
                for pattern in rag_results['extracted_patterns']
                if needle in str(pattern).lower()
            ]

            result = {
                "success": True,
                "field": field_name,
                "values_found": len(field_values),
                "values": field_values,
                "source_invoices": rag_results['num_sources'],
            }

            return json.dumps(result, indent=2, ensure_ascii=False)

        except Exception as e:
            logger.error("Field extraction error: %s", e)
            return json.dumps({"error": f"Field extraction failed: {str(e)}"})

    def api_list_available_invoices(self) -> str:
        """List all available invoices in the RAG system.

        Returns:
            JSON string with one entry per distinct ``source_file`` among the
            RAG system's chunks and the number of chunks per source.
        """
        if not self.is_trained:
            return self._not_trained_json()

        try:
            chunks = self.rag_system.chunks
            chunk_counts = Counter(chunk.source_file for chunk in chunks)

            result = {
                "success": True,
                "total_invoices": len(chunk_counts),
                "total_chunks": len(chunks),
                "invoices": [
                    {"source": source, "chunks": chunk_counts[source]}
                    for source in sorted(chunk_counts)
                ],
            }

            return json.dumps(result, indent=2, ensure_ascii=False)

        except Exception as e:
            logger.error("List invoices error: %s", e)
            return json.dumps({"error": f"Failed to list invoices: {str(e)}"})

    def api_upload_and_train(self, files: List[str]) -> str:
        """Upload invoices and train the RAG system.

        Args:
            files: List of file paths to invoice PDFs

        Returns:
            JSON string with the training summary, or an ``error`` key.
        """
        try:
            if not files:
                return json.dumps({"error": "No files provided"})

            pdf_paths = [
                file_path
                for file_path in files
                if file_path
                and os.path.exists(file_path)
                and file_path.lower().endswith('.pdf')
            ]
            pdf_count = len(pdf_paths)
            if pdf_count == 0:
                return json.dumps({"error": "No valid PDF files found"})

            # The context manager removes the staging directory even when
            # training raises (previously it leaked on failure).
            with tempfile.TemporaryDirectory() as training_dir:
                self._stage_pdfs(pdf_paths, training_dir)

                # Train the system
                self.rag_system.train_on_invoices(training_dir)
                self.is_trained = True

                # Get summary
                summary = self.rag_system.get_pattern_summary()

            # Update training history
            self._record_training('upload_and_train', summary)

            result = {
                "success": True,
                "message": f"Training completed successfully! Processed {pdf_count} PDF files.",
                "invoices_processed": summary['total_invoices'],
                "chunks_created": summary['total_chunks'],
                "summary": summary,
            }

            return json.dumps(result, indent=2, ensure_ascii=False)

        except Exception as e:
            logger.error("Upload and train error: %s", e)
            return json.dumps({"error": f"Training failed: {str(e)}"})

    # ------------------------------------------------------------------
    # Regular Interface Functions
    # ------------------------------------------------------------------
    def upload_and_train_files(self, files, progress=gr.Progress()) -> tuple:
        """Handle file upload and training from the UI.

        Args:
            files: Uploaded files; each item is either a path string or an
                object exposing the path as ``.name``.
            progress: Gradio progress tracker.

        Returns:
            ``(status text, summary JSON, formatted training history)``; the
            last two are empty strings on failure.
        """
        if not files:
            return "❌ No files uploaded", "", ""

        try:
            progress(0, desc="Processing uploaded files...")

            # Filter PDF files; accept both plain paths and file-like wrappers.
            paths = [getattr(f, "name", f) for f in files]
            pdf_files = [p for p in paths if p and str(p).lower().endswith('.pdf')]

            if not pdf_files:
                return "❌ No PDF files found in upload", "", ""

            progress(0.2, desc=f"Found {len(pdf_files)} PDF files")

            # Stage the files in a directory that is removed even on failure.
            with tempfile.TemporaryDirectory() as training_dir:
                self._stage_pdfs(pdf_files, training_dir)

                progress(0.4, desc="Training RAG system...")

                # Train the system
                self.rag_system.train_on_invoices(training_dir)

                progress(0.8, desc="Building index...")

                self.is_trained = True

                # Get summary
                summary = self.rag_system.get_pattern_summary()

            progress(1.0, desc="Training complete!")

            # Update training history
            self._record_training('file_upload', summary)

            status = (
                f"✅ Training completed successfully!\n"
                f"📁 Processed {len(pdf_files)} PDF files\n"
                f"📄 Created {summary['total_chunks']} chunks\n"
                f"🚀 API endpoints are now available!"
            )

            return self._ui_result(status, summary)

        except Exception as e:
            logger.error("Upload training error: %s", e)
            return f"❌ Training failed: {str(e)}", "", ""

    def train_rag_system(self, invoice_folder: str, progress=gr.Progress()) -> tuple:
        """Train the RAG system on the PDFs found in ``invoice_folder``.

        Returns:
            ``(status text, summary JSON, formatted training history)``; the
            last two are empty strings on failure.
        """
        # isdir (not just exists) so a path to a regular file is rejected
        # here instead of failing inside os.listdir.
        if not invoice_folder or not os.path.isdir(invoice_folder):
            return "❌ Invalid folder path", "", ""

        try:
            progress(0, desc="Starting training...")

            # Count PDF files (case-insensitive, like the upload handlers).
            pdf_files = [
                f for f in os.listdir(invoice_folder) if f.lower().endswith('.pdf')
            ]
            if not pdf_files:
                return "❌ No PDF files found in folder", "", ""

            progress(0.2, desc=f"Found {len(pdf_files)} PDF files")

            # Train the system
            self.rag_system.train_on_invoices(invoice_folder)

            progress(0.8, desc="Building index...")

            self.is_trained = True

            # Get summary
            summary = self.rag_system.get_pattern_summary()

            progress(1.0, desc="Training complete!")

            # Update training history
            self._record_training('folder_training', summary, folder=invoice_folder)

            status = (
                f"✅ Training completed successfully!\n"
                f"📁 Processed {summary['total_invoices']} invoices\n"
                f"📄 Created {summary['total_chunks']} chunks\n"
                f"🚀 API endpoints are now available!"
            )

            return self._ui_result(status, summary)

        except Exception as e:
            logger.error("Training error: %s", e)
            return f"❌ Training failed: {str(e)}", "", ""

    def load_model(self, model_path: str) -> tuple:
        """Load a pre-trained model from ``model_path``.

        Returns:
            ``(status text, summary JSON, formatted training history)``; the
            last two are empty strings on failure.
        """
        if not model_path or not os.path.exists(model_path):
            return "❌ Invalid model path", "", ""

        try:
            # NOTE(review): the '.pkl' extension used by save_model suggests
            # pickle-based persistence; if so, loading a file from an untrusted
            # source can execute arbitrary code — confirm in InvoiceRAGSystem.
            self.rag_system.load_model(model_path)
            self.is_trained = True

            summary = self.rag_system.get_pattern_summary()

            status = (
                f"✅ Model loaded successfully!\n"
                f"📁 Loaded {summary['total_invoices']} invoices\n"
                f"📄 {summary['total_chunks']} chunks available\n"
                f"🚀 API endpoints are now available!"
            )

            return self._ui_result(status, summary)

        except Exception as e:
            logger.error("Model loading error: %s", e)
            return f"❌ Failed to load model: {str(e)}", "", ""

    def save_model(self, save_path: str) -> str:
        """Save the current model to ``save_path`` (``.pkl`` is appended if missing).

        Returns:
            A human-readable status message.
        """
        if not self.is_trained:
            return "❌ No trained model to save"

        if not save_path:
            return "❌ Please provide a save path"

        try:
            # Ensure .pkl extension
            if not save_path.endswith('.pkl'):
                save_path += '.pkl'

            self.rag_system.save_model(save_path)
            return f"✅ Model saved to {save_path}"

        except Exception as e:
            logger.error("Model saving error: %s", e)
            return f"❌ Failed to save model: {str(e)}"

    def query_invoices(self, query: str, provider: str, model: str,
                       context_sections: List[str], top_k: int,
                       temperature: float, max_tokens: int) -> tuple:
        """Answer ``query`` with the selected LLM using retrieved invoice context.

        Args:
            query: The user's question.
            provider: LLM provider name; must be one of the available providers.
            model: Model identifier for the provider.
            context_sections: Sections to restrict retrieval to; empty means all.
            top_k: Maximum number of retrieved chunks passed to the LLM.
            temperature: Sampling temperature for the LLM.
            max_tokens: Completion token limit for the LLM.

        Returns:
            ``(LLM answer, markdown describing the retrieved context,
            extracted patterns as JSON)``; the last two are empty on failure.
        """
        if not self.is_trained:
            return "❌ RAG system not trained. Please train or load a model first.", "", ""

        if not query or not query.strip():
            return "❌ Please enter a query", "", ""

        if not provider or provider not in self.llm_manager.get_available_providers():
            return "❌ Please select a valid LLM provider", "", ""

        if not model:
            return "❌ Please select a model", "", ""

        try:
            # UI sliders may deliver numbers as floats; slicing and the token
            # limit need integers.
            top_k = int(top_k)
            max_tokens = int(max_tokens)

            # Extract information using RAG
            rag_results = self.rag_system.extract_invoice_info(
                query,
                context_sections if context_sections else None
            )

            # Prepare context for LLM
            context_chunks = rag_results['context_chunks'][:top_k]
            context_text = "\n\n".join(
                f"[{chunk['type']}] From {chunk['source']}:\n{chunk['content']}"
                for chunk in context_chunks
            )
            patterns_json = json.dumps(rag_results['extracted_patterns'], indent=2)

            # Create prompt for LLM
            prompt = "\n".join([
                "Based on the following invoice data, please answer the user's question.",
                "",
                "Context from invoices:",
                context_text,
                "",
                "Extracted patterns:",
                patterns_json,
                "",
                f"User question: {query}",
                "",
                "Please provide a detailed and accurate answer based on the invoice data "
                "provided. If you cannot find specific information in the context, "
                "please mention that.",
            ])

            # Generate response using selected LLM
            llm_response = self.llm_manager.generate_response(
                provider, model, prompt, max_tokens, temperature
            )

            # Format RAG context info; sorted() makes the section list deterministic.
            section_names = ', '.join(sorted({chunk['type'] for chunk in context_chunks}))
            info_parts = [
                "**RAG Context Retrieved:**\n"
                f"- Sources: {rag_results['num_sources']} invoices\n"
                f"- Chunks: {len(context_chunks)} relevant sections\n"
                f"- Sections: {section_names}\n"
                "\n"
                "**Top Retrieved Chunks:**\n"
            ]
            for i, chunk in enumerate(context_chunks[:3], 1):
                info_parts.append(
                    f"\n{i}. [{chunk['type']}] {chunk['source']} "
                    f"(Score: {chunk['score']:.3f})\n"
                )
                info_parts.append(f"   {self._truncate(chunk['content'], 200)}\n")
            rag_info = "".join(info_parts)

            return llm_response, rag_info, patterns_json

        except Exception as e:
            logger.error("Query error: %s", e)
            return f"❌ Query failed: {str(e)}", "", ""

    def format_training_history(self) -> str:
        """Format the training history as markdown, most recent run first."""
        if not self.training_history:
            return "No training history available"

        parts = ["**Training History:**\n\n"]
        for i, entry in enumerate(reversed(self.training_history), 1):
            parts.append(f"{i}. **{entry['timestamp'][:19]}**\n")
            parts.append(f"   🔧 Method: {entry['method'].replace('_', ' ').title()}\n")
            if 'folder' in entry:
                parts.append(f"   📁 Folder: {entry['folder']}\n")
            parts.append(
                f"   📊 {entry['num_invoices']} invoices, {entry['num_chunks']} chunks\n\n"
            )

        return "".join(parts)

    def get_system_status(self) -> str:
        """Return a markdown report of RAG, API and LLM-provider status."""
        available_providers = self.llm_manager.get_available_providers()

        lines = [
            "**System Status:**",
            "",
            "**RAG System:**",
            f"- Trained: {'✅ Yes' if self.is_trained else '❌ No'}",
            f"- Chunks: {len(self.rag_system.chunks) if self.is_trained else 0}",
            f"- Index: {'✅ Built' if self.rag_system.index is not None else '❌ Not built'}",
            "",
            "**Gradio API:**",
            f"- Status: {'✅ Active' if self.is_trained else '⏳ Waiting for training'}",
            f"- Available Endpoints: {'4 endpoints ready' if self.is_trained else 'Training required'}",
            "",
            "**Available LLM Providers:**",
        ]

        for provider in available_providers:
            models = self.llm_manager.get_models_for_provider(provider)
            lines.append(f"- **{provider.upper()}**: {', '.join(models)}")

        if not available_providers:
            lines.append("❌ No LLM providers configured. Please set API keys.")

        return "\n".join(lines) + "\n"

    def get_api_info(self) -> str:
        """Return markdown describing the API endpoints and usage examples."""
        if not self.is_trained:
            return "❌ API endpoints not available - RAG system not trained"

        # NOTE(review): api_upload_and_train is listed here but is not attached
        # to any event in create_interface — confirm how it is meant to be exposed.
        api_endpoints = [
            "🔍 `/api/query_invoice_info` - Extract information from invoices",
            "📋 `/api/get_invoice_summary` - Get summary of all processed invoices",
            "🔎 `/api/extract_specific_field` - Extract specific fields from invoices",
            "📄 `/api/list_available_invoices` - List all available invoice sources",
            "📤 `/api/upload_and_train` - Upload and train on new invoices",
        ]

        lines = [
            "**Gradio API Information:**",
            "",
            "**Available Endpoints:**",
            *api_endpoints,
            "",
            "**API Status:** ✅ Active",
            f"**Endpoint Count:** {len(api_endpoints)}",
            "",
            "**Usage Examples:**",
            "",
            "**Python:**",
            "```python",
            "import requests",
            "",
            "# Query invoices",
            'response = requests.post("http://localhost:7860/api/predict", json={',
            '    "data": ["What are all invoice numbers?", "header,totals"],',
            '    "fn_index": 0  # api_query_invoice_info function index',
            "})",
            "",
            "# Get summary",
            'response = requests.post("http://localhost:7860/api/predict", json={',
            '    "data": [],',
            '    "fn_index": 1  # api_get_invoice_summary function index',
            "})",
            "```",
            "",
            "**cURL:**",
            "```bash",
            "# Query invoices",
            'curl -X POST "http://localhost:7860/api/predict" \\',
            '  -H "Content-Type: application/json" \\',
            '  -d \'{"data": ["Extract vendor information", "vendor"], "fn_index": 0}\'',
            "",
            "# Get invoice summary",
            'curl -X POST "http://localhost:7860/api/predict" \\',
            '  -H "Content-Type: application/json" \\',
            '  -d \'{"data": [], "fn_index": 1}\'',
            "```",
            "",
            "**Base URL:** `http://localhost:7860`",
            "**API Documentation:** Available at `http://localhost:7860/docs`",
        ]

        return "\n".join(lines) + "\n"

    def create_interface(self):
        """Create the Gradio interface with built-in API support.

        Returns:
            The assembled ``gr.Blocks`` application (not yet launched).
        """
        # Resolve provider/model defaults once instead of re-querying the
        # manager for every widget argument.
        available_providers = self.llm_manager.get_available_providers()
        default_provider = available_providers[0] if available_providers else None
        default_models = (
            self.llm_manager.get_models_for_provider(default_provider)
            if default_provider else []
        )

        with gr.Blocks(title="Invoice RAG System with API", theme=gr.themes.Soft()) as demo:
            gr.Markdown("# 📄 Invoice RAG System with Gradio API")
            gr.Markdown("Train on invoice PDFs and query them using different language models or API endpoints")

            with gr.Tabs():
                # Training Tab
                with gr.TabItem("🎯 Training"):
                    gr.Markdown("## Train RAG Model")

                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("### 📤 Upload Invoice PDFs")
                            upload_files = gr.File(
                                label="Upload Invoice PDFs",
                                file_count="multiple",
                                file_types=[".pdf"],
                                height=200
                            )
                            upload_train_btn = gr.Button("🚀 Upload & Train", variant="primary")

                        with gr.Column():
                            gr.Markdown("### 📁 Train from Folder")
                            invoice_folder = gr.Textbox(
                                label="Invoice Folder Path",
                                placeholder="Path to folder containing PDF invoices"
                            )
                            folder_train_btn = gr.Button("🚀 Train from Folder", variant="secondary")

                    training_status = gr.Textbox(
                        label="Training Status",
                        interactive=False,
                        lines=4
                    )

                    with gr.Row():
                        with gr.Column():
                            summary_output = gr.Code(
                                label="Pattern Summary",
                                language="json",
                                lines=10
                            )
                        with gr.Column():
                            history_output = gr.Markdown(
                                label="Training History"
                            )

                    gr.Markdown("### 💾 Save/Load Model")
                    with gr.Row():
                        with gr.Column():
                            save_path = gr.Textbox(
                                label="Save Path",
                                placeholder="model_name.pkl"
                            )
                            save_btn = gr.Button("💾 Save Model")
                            save_status = gr.Textbox(
                                label="Save Status",
                                interactive=False
                            )

                        with gr.Column():
                            model_path = gr.Textbox(
                                label="Model Path",
                                placeholder="Path to saved model (.pkl)"
                            )
                            load_btn = gr.Button("📥 Load Model")

                # Query Tab
                with gr.TabItem("🔍 Query"):
                    gr.Markdown("## Query Invoice Data")

                    with gr.Row():
                        with gr.Column(scale=2):
                            query_input = gr.Textbox(
                                label="Your Question",
                                placeholder="What are the invoice numbers?",
                                lines=2
                            )

                            provider_dropdown = gr.Dropdown(
                                choices=available_providers,
                                label="LLM Provider",
                                value=default_provider
                            )

                            # Preselect a model so the first query does not
                            # fail with an empty model name.
                            model_dropdown = gr.Dropdown(
                                label="Model",
                                choices=default_models,
                                value=default_models[0] if default_models else None
                            )

                        with gr.Column(scale=1):
                            context_sections = gr.CheckboxGroup(
                                choices=["header", "vendor", "client", "items", "totals", "footer"],
                                label="Context Sections",
                                info="Leave empty for all sections"
                            )

                            top_k = gr.Slider(
                                minimum=1,
                                maximum=20,
                                value=5,
                                step=1,
                                label="Top K Results"
                            )

                            temperature = gr.Slider(
                                minimum=0.0,
                                maximum=2.0,
                                value=0.7,
                                step=0.1,
                                label="Temperature"
                            )

                            max_tokens = gr.Slider(
                                minimum=100,
                                maximum=8192,
                                value=4096,
                                step=100,
                                label="Max Tokens"
                            )

                    query_btn = gr.Button("🤖 Query RAG System", variant="primary")

                    with gr.Row():
                        with gr.Column():
                            llm_response = gr.Textbox(
                                label="LLM Response",
                                lines=10,
                                interactive=False
                            )

                        with gr.Column():
                            rag_context = gr.Markdown(
                                label="RAG Context"
                            )

                    patterns_output = gr.Code(
                        label="Extracted Patterns",
                        language="json",
                        lines=5
                    )

                # API Tools Tab
                with gr.TabItem("🔧 API Tools"):
                    gr.Markdown("## Test API Functions Directly")
                    gr.Markdown("These functions are exposed via Gradio's built-in API system")

                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("### Query Invoice Info")
                            api_query = gr.Textbox(
                                label="Query",
                                placeholder="What are all the invoice numbers?"
                            )
                            api_sections = gr.Textbox(
                                label="Context Sections (comma-separated)",
                                placeholder="header,vendor,totals",
                                info="Optional: specify which sections to focus on"
                            )
                            api_query_btn = gr.Button("🔍 Run API Query")
                            api_query_output = gr.Code(language="json", lines=8)

                        with gr.Column():
                            gr.Markdown("### Extract Specific Field")
                            field_name = gr.Textbox(
                                label="Field Name",
                                placeholder="invoice_number, total, vendor_name"
                            )
                            invoice_source = gr.Textbox(
                                label="Invoice Source (optional)",
                                placeholder="Leave empty to search all invoices"
                            )
                            extract_btn = gr.Button("🔎 Extract Field")
                            extract_output = gr.Code(language="json", lines=8)

                    with gr.Row():
                        with gr.Column():
                            summary_btn = gr.Button("📋 Get Invoice Summary")
                            summary_api_output = gr.Code(language="json", lines=6)

                        with gr.Column():
                            list_btn = gr.Button("📄 List Available Invoices")
                            list_output = gr.Code(language="json", lines=6)

                # Status Tab
                with gr.TabItem("📊 Status & API"):
                    gr.Markdown("## System Status & API Information")

                    with gr.Row():
                        status_btn = gr.Button("🔄 Refresh Status")
                        mcp_info_btn = gr.Button("🚀 Get API Info")

                    with gr.Row():
                        with gr.Column():
                            status_output = gr.Markdown()
                        with gr.Column():
                            mcp_info_output = gr.Markdown()

            # Predefined queries
            gr.Markdown("## 📝 Example Queries")
            gr.Examples(
                examples=[
                    ["What are all the invoice numbers?"],
                    ["Show me vendor information"],
                    ["Extract total amounts from all invoices"],
                    ["Find products with quantities and prices"],
                    ["What are the invoice dates?"],
                    ["List all companies mentioned in the invoices"],
                    ["What payment terms are mentioned?"],
                    ["Extract line items with descriptions and amounts"]
                ],
                inputs=[query_input],
                label="Click to use example queries"
            )

            # Event handlers
            def update_models(provider):
                """Refresh the model dropdown for the newly selected provider."""
                models = (
                    self.llm_manager.get_models_for_provider(provider) if provider else []
                )
                return gr.Dropdown(choices=models, value=models[0] if models else None)

            provider_dropdown.change(
                update_models,
                inputs=[provider_dropdown],
                outputs=[model_dropdown]
            )

            upload_train_btn.click(
                self.upload_and_train_files,
                inputs=[upload_files],
                outputs=[training_status, summary_output, history_output]
            )

            folder_train_btn.click(
                self.train_rag_system,
                inputs=[invoice_folder],
                outputs=[training_status, summary_output, history_output]
            )

            load_btn.click(
                self.load_model,
                inputs=[model_path],
                outputs=[training_status, summary_output, history_output]
            )

            save_btn.click(
                self.save_model,
                inputs=[save_path],
                outputs=[save_status]
            )

            query_btn.click(
                self.query_invoices,
                inputs=[
                    query_input, provider_dropdown, model_dropdown,
                    context_sections, top_k, temperature, max_tokens
                ],
                outputs=[llm_response, rag_context, patterns_output]
            )

            # API tool handlers
            api_query_btn.click(
                self.api_query_invoice_info,
                inputs=[api_query, api_sections],
                outputs=[api_query_output]
            )

            extract_btn.click(
                self.api_extract_specific_field,
                inputs=[field_name, invoice_source],
                outputs=[extract_output]
            )

            summary_btn.click(
                self.api_get_invoice_summary,
                outputs=[summary_api_output]
            )

            # Previously wired to api_get_invoice_summary, so the "List
            # Available Invoices" button showed the summary instead of the list.
            list_btn.click(
                self.api_list_available_invoices,
                outputs=[list_output]
            )

            status_btn.click(
                self.get_system_status,
                outputs=[status_output]
            )

            mcp_info_btn.click(
                self.get_api_info,
                outputs=[mcp_info_output]
            )

            # Initialize status on load
            demo.load(
                lambda: (self.get_system_status(), self.get_api_info()),
                outputs=[status_output, mcp_info_output]
            )

        return demo


def main():
    """Entry point: resolve API keys, build the interface and launch it."""
    # Setup
    setup_environment()

    # Check API keys with HF Spaces support
    required_vars = {
        "GROQ_API_KEY": "Groq API",
    }

    available_apis = []
    for var, name in required_vars.items():
        # Check both the plain variable and its HF_-prefixed variant
        direct_value = os.getenv(var)
        hf_value = os.getenv(f"HF_{var}")
        if direct_value or hf_value:
            available_apis.append(name)
            # Use the HF_-prefixed value only when the plain one is missing
            if hf_value and not direct_value:
                os.environ[var] = hf_value

    if not available_apis:
        print("⚠️ Warning: No API keys found.")
        print("Set GROQ_API_KEY in HF Spaces secrets or environment")

    # Create interface
    interface = InvoiceRAGInterface()
    demo = interface.create_interface()

    print("🚀 Starting Invoice RAG System on Hugging Face Spaces...")

    # HF Spaces optimized launch
    demo.launch(
        server_name="0.0.0.0",  # Listen on all network interfaces
        server_port=7860,       # Default Gradio port
        share=True,             # Enable sharing
        debug=False,            # Disable debug mode in production
        auth=None,              # No authentication required
        show_api=True,          # Show API documentation
        max_threads=40,         # Limit concurrent threads
    )


if __name__ == "__main__":
    main()