Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import json | |
| import os | |
| import asyncio | |
| import shutil | |
| import tempfile | |
| from typing import Dict, Any, Optional, List, Union | |
| import logging | |
| from datetime import datetime | |
| # LLM integrations | |
| import groq | |
| # Import the RAG system | |
| from invoice_rag_system import InvoiceRAGSystem | |
# Logging: configure the root logger once at import time so INFO-level
# progress messages (client init, errors) are emitted to stderr.
logging.basicConfig(level=logging.INFO)
# Shared module logger used by LLMManager and InvoiceRAGInterface.
logger = logging.getLogger(name="invoice-rag-gradio")
def setup_environment():
    """Prepare the working directory for running the app (locally or on HF Spaces).

    Creates a ``sample_invoices`` folder in the current working directory if it
    does not exist yet, and prints the Space id when the ``SPACE_ID``
    environment variable is set.

    Returns:
        Always ``True``.
    """
    # exist_ok=True removes the check-then-create race of os.path.exists + makedirs
    # and makes repeated calls (main() and InvoiceRAGInterface.__init__) harmless.
    os.makedirs("sample_invoices", exist_ok=True)

    space_id = os.getenv("SPACE_ID")
    if space_id:
        print(f"๐ Running on Hugging Face Spaces: {space_id}")
    return True
class LLMManager:
    """Manage different LLM providers.

    ``providers`` maps a provider name to a config dict with:
      * ``client``      - the SDK client, or ``None`` until an API key is found;
      * ``models``      - model ids offered for that provider;
      * ``api_key_env`` - environment variable that holds the API key.
    """

    def __init__(self):
        self.providers = {
            "groq": {
                "client": None,
                "models": ["llama-3.3-70b-versatile", "mixtral-8x7b-32768", "llama-3.1-8b-instant"],
                "api_key_env": "GROQ_API_KEY",
            },
        }
        self.initialize_clients()

    def initialize_clients(self):
        """Create a client for each provider whose API key is present in the environment.

        A failure is logged and leaves that provider's ``client`` as ``None``.
        """
        groq_config = self.providers["groq"]
        api_key = os.getenv(groq_config["api_key_env"])
        if api_key:
            try:
                groq_config["client"] = groq.Client(api_key=api_key)
                logger.info("Groq client initialized")
            except Exception as e:
                logger.error(f"Failed to initialize Groq client: {e}")

    def get_available_providers(self) -> List[str]:
        """Return the names of providers that have an initialized client."""
        return [name for name, config in self.providers.items()
                if config["client"] is not None]

    def get_models_for_provider(self, provider: str) -> List[str]:
        """Return the models of *provider*, or ``[]`` if it is unknown or not initialized.

        A copy is returned so callers (e.g. UI widgets) cannot mutate the
        manager's configuration.
        """
        config = self.providers.get(provider)
        if config and config["client"]:
            return list(config["models"])
        return []

    def generate_response(self, provider: str, model: str, prompt: str,
                          max_tokens: int = 4096, temperature: float = 0.7) -> str:
        """Generate a completion for *prompt* with the given provider and model.

        Args:
            provider: Provider name (key of ``self.providers``).
            model: Model id to request.
            prompt: Sent as a single user message.
            max_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature.

        Returns:
            The stripped response text, or a string starting with ``"Error: "``
            when the provider is unknown/uninitialized or the request fails.
            Errors are reported in-band rather than raised.
        """
        config = self.providers.get(provider)
        client = config["client"] if config else None
        # Check before calling: a missing API key leaves client=None, which would
        # otherwise surface as a confusing AttributeError on `.chat`.
        if client is None:
            return f"Error: Provider {provider} not supported or not initialized"
        try:
            if provider == "groq":
                response = client.chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model=model,
                    max_tokens=max_tokens,
                    temperature=temperature,
                )
                # The SDK may return content=None (e.g. empty completion).
                content = response.choices[0].message.content
                return (content or "").strip()
            return f"Error: Provider {provider} not supported or not initialized"
        except Exception as e:
            logger.error(f"Error generating response with {provider}/{model}: {e}")
            return f"Error: {str(e)}"
class InvoiceRAGInterface:
    """Gradio interface for Invoice RAG system with built-in API.

    Combines an ``InvoiceRAGSystem`` (retrieval over invoice PDFs) with an
    ``LLMManager`` (answer generation). The ``api_*`` methods return JSON
    strings and are registered as named Gradio API endpoints in
    ``create_interface``. The remaining public methods back the UI widgets.
    """

    # (endpoint name, description) for every method exposed through Gradio's API.
    # create_interface registers the matching listeners with exactly these
    # ``api_name`` values, so what get_api_info advertises actually exists.
    API_ENDPOINTS = (
        ("api_query_invoice_info", "Extract information from invoices"),
        ("api_get_invoice_summary", "Get summary of all processed invoices"),
        ("api_extract_specific_field", "Extract specific fields from invoices"),
        ("api_list_available_invoices", "List all available invoice sources"),
        ("api_upload_and_train", "Upload and train on new invoices"),
    )

    # Error text returned by every api_* method that needs a trained system.
    _NOT_TRAINED_ERROR = "RAG system not trained. Please train the system first with invoice PDFs."

    def __init__(self):
        setup_environment()
        self.rag_system = InvoiceRAGSystem()
        self.llm_manager = LLMManager()
        # Becomes True once a model is trained or loaded; gates queries and API calls.
        self.is_trained = False
        # One dict per training run (see _record_training), oldest first.
        self.training_history = []
        # NOTE(review): not used or cleaned up anywhere in this class; kept only
        # in case external code reads the attribute.
        self.temp_upload_dir = tempfile.mkdtemp()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _pdf_paths(files) -> List[str]:
        """Return the paths among *files* that end in ``.pdf`` (case-insensitive).

        Entries may be plain path strings or Gradio file wrappers exposing the
        path as ``.name``; falsy entries are skipped.
        """
        paths = (getattr(f, "name", f) for f in files)
        return [str(p) for p in paths if p and str(p).lower().endswith('.pdf')]

    def _record_training(self, method: str, summary: Dict[str, Any], **extra: Any) -> None:
        """Append a training-run entry (timestamp, method, extras, counts) to the history."""
        entry: Dict[str, Any] = {'timestamp': datetime.now().isoformat(), 'method': method}
        entry.update(extra)
        entry['num_invoices'] = summary['total_invoices']
        entry['num_chunks'] = summary['total_chunks']
        self.training_history.append(entry)

    # ------------------------------------------------------------------
    # API Functions (exposed via Gradio's built-in API)
    # ------------------------------------------------------------------
    def api_query_invoice_info(self, query: str, context_sections: Optional[str] = None) -> str:
        """Extract information from invoices using the RAG system.

        Args:
            query: The question to ask about the invoices
            context_sections: Comma-separated list of sections to focus on (header,vendor,client,items,totals,footer)

        Returns:
            Extracted information and patterns from the invoice data
        """
        if not self.is_trained:
            return json.dumps({"error": self._NOT_TRAINED_ERROR})
        # API callers can send null, so guard before calling .strip().
        if not query or not query.strip():
            return json.dumps({"error": "Please provide a query"})
        try:
            # An empty/blank section list means "all sections", same as the UI.
            sections = None
            if context_sections:
                sections = [s.strip() for s in context_sections.split(',') if s.strip()] or None

            rag_results = self.rag_system.extract_invoice_info(query, sections)

            relevant_chunks = []
            for chunk in rag_results['context_chunks'][:5]:
                content = chunk['content']
                relevant_chunks.append({
                    "source": chunk['source'],
                    "type": chunk['type'],
                    # Truncate long chunks to keep the JSON payload small.
                    "content": content[:500] + "..." if len(content) > 500 else content,
                    "relevance_score": chunk['score'],
                })

            response = {
                "success": True,
                "query": query,
                "sources_found": rag_results['num_sources'],
                "chunks_retrieved": len(rag_results['context_chunks']),
                "extracted_patterns": rag_results['extracted_patterns'],
                "relevant_chunks": relevant_chunks,
            }
            return json.dumps(response, indent=2, ensure_ascii=False)
        except Exception as e:
            logger.error(f"API Query error: {e}")
            return json.dumps({"error": f"Query failed: {str(e)}"})

    def api_get_invoice_summary(self) -> str:
        """Get a summary of all processed invoices and their patterns."""
        if not self.is_trained:
            return json.dumps({"error": self._NOT_TRAINED_ERROR})
        try:
            summary = self.rag_system.get_pattern_summary()
            return json.dumps({"success": True, "summary": summary}, indent=2, ensure_ascii=False)
        except Exception as e:
            logger.error(f"API summary error: {e}")
            return json.dumps({"error": f"Failed to get summary: {str(e)}"})

    def api_extract_specific_field(self, field_name: str, invoice_source: Optional[str] = None) -> str:
        """Extract a specific field from invoices.

        Args:
            field_name: The field to extract (e.g., 'invoice_number', 'total', 'vendor_name')
            invoice_source: Optional specific invoice to search in

        Returns:
            JSON with the extracted patterns whose text mentions the field name.
        """
        if not self.is_trained:
            return json.dumps({"error": self._NOT_TRAINED_ERROR})
        # An empty name would substring-match every pattern.
        if not field_name or not field_name.strip():
            return json.dumps({"error": "Please provide a field name"})
        try:
            field_name = field_name.strip()
            query = f"Find all {field_name} values"
            if invoice_source and invoice_source.strip():
                # Only steers retrieval; results are not filtered by source.
                query += f" from {invoice_source.strip()}"
            rag_results = self.rag_system.extract_invoice_info(query)

            needle = field_name.lower()
            field_values = [pattern for pattern in rag_results['extracted_patterns']
                            if needle in str(pattern).lower()]
            result = {
                "success": True,
                "field": field_name,
                "values_found": len(field_values),
                "values": field_values,
                "source_invoices": rag_results['num_sources'],
            }
            return json.dumps(result, indent=2, ensure_ascii=False)
        except Exception as e:
            logger.error(f"API field extraction error: {e}")
            return json.dumps({"error": f"Field extraction failed: {str(e)}"})

    def api_list_available_invoices(self) -> str:
        """List all available invoices in the RAG system."""
        if not self.is_trained:
            return json.dumps({"error": self._NOT_TRAINED_ERROR})
        try:
            # Count chunks per source file; the keys double as the set of sources.
            chunk_counts: Dict[str, int] = {}
            for chunk in self.rag_system.chunks:
                chunk_counts[chunk.source_file] = chunk_counts.get(chunk.source_file, 0) + 1
            result = {
                "success": True,
                "total_invoices": len(chunk_counts),
                "total_chunks": len(self.rag_system.chunks),
                "invoices": [
                    {"source": source, "chunks": chunk_counts[source]}
                    for source in sorted(chunk_counts)
                ],
            }
            return json.dumps(result, indent=2, ensure_ascii=False)
        except Exception as e:
            logger.error(f"API list invoices error: {e}")
            return json.dumps({"error": f"Failed to list invoices: {str(e)}"})

    def api_upload_and_train(self, files: List[str]) -> str:
        """Upload invoices and train the RAG system.

        Args:
            files: List of file paths to invoice PDFs

        Returns:
            JSON with the training result, or an ``error`` entry.
        """
        if not files:
            return json.dumps({"error": "No files provided"})
        try:
            pdf_paths = [p for p in self._pdf_paths(files) if os.path.isfile(p)]
            if not pdf_paths:
                return json.dumps({"error": "No valid PDF files found"})
            # TemporaryDirectory is removed even if training raises.
            with tempfile.TemporaryDirectory() as training_dir:
                for path in pdf_paths:
                    shutil.copy2(path, os.path.join(training_dir, os.path.basename(path)))
                self.rag_system.train_on_invoices(training_dir)
                self.is_trained = True
                summary = self.rag_system.get_pattern_summary()
            self._record_training('upload_and_train', summary)
            result = {
                "success": True,
                "message": f"Training completed successfully! Processed {len(pdf_paths)} PDF files.",
                "invoices_processed": summary['total_invoices'],
                "chunks_created": summary['total_chunks'],
                "summary": summary,
            }
            return json.dumps(result, indent=2, ensure_ascii=False)
        except Exception as e:
            logger.error(f"Upload and train error: {e}")
            return json.dumps({"error": f"Training failed: {str(e)}"})

    # ------------------------------------------------------------------
    # Regular Interface Functions
    # ------------------------------------------------------------------
    def upload_and_train_files(self, files, progress=gr.Progress()) -> tuple:
        """Train on PDFs uploaded through the UI.

        Returns:
            ``(status_text, summary_json, history_markdown)``.
        """
        if not files:
            return "โ No files uploaded", "", ""
        try:
            progress(0, desc="Processing uploaded files...")
            pdf_paths = self._pdf_paths(files)
            if not pdf_paths:
                return "โ No PDF files found in upload", "", ""
            progress(0.2, desc=f"Found {len(pdf_paths)} PDF files")

            # Copy into a private directory so training sees only this upload;
            # TemporaryDirectory is removed even if training raises.
            with tempfile.TemporaryDirectory() as training_dir:
                for path in pdf_paths:
                    shutil.copy2(path, os.path.join(training_dir, os.path.basename(path)))
                progress(0.4, desc="Training RAG system...")
                self.rag_system.train_on_invoices(training_dir)
                progress(0.8, desc="Building index...")
                self.is_trained = True
                summary = self.rag_system.get_pattern_summary()
            progress(1.0, desc="Training complete!")

            self._record_training('file_upload', summary)
            status = (f"โ Training completed successfully!\n"
                      f"๐ Processed {len(pdf_paths)} PDF files\n"
                      f"๐ Created {summary['total_chunks']} chunks\n"
                      f"๐ API endpoints are now available!")
            summary_text = json.dumps(summary, indent=2, ensure_ascii=False)
            return status, summary_text, self.format_training_history()
        except Exception as e:
            logger.error(f"Upload training error: {e}")
            return f"โ Training failed: {str(e)}", "", ""

    def train_rag_system(self, invoice_folder: str, progress=gr.Progress()) -> tuple:
        """Train the RAG system on the PDFs in a server-side folder.

        Returns:
            ``(status_text, summary_json, history_markdown)``.
        """
        if not invoice_folder or not os.path.isdir(invoice_folder):
            return "โ Invalid folder path", "", ""
        try:
            progress(0, desc="Starting training...")
            # Case-insensitive, consistent with the upload path (".PDF" counts too).
            pdf_files = [f for f in os.listdir(invoice_folder) if f.lower().endswith('.pdf')]
            if not pdf_files:
                return "โ No PDF files found in folder", "", ""
            progress(0.2, desc=f"Found {len(pdf_files)} PDF files")

            self.rag_system.train_on_invoices(invoice_folder)
            progress(0.8, desc="Building index...")
            self.is_trained = True
            summary = self.rag_system.get_pattern_summary()
            progress(1.0, desc="Training complete!")

            self._record_training('folder_training', summary, folder=invoice_folder)
            status = (f"โ Training completed successfully!\n"
                      f"๐ Processed {summary['total_invoices']} invoices\n"
                      f"๐ Created {summary['total_chunks']} chunks\n"
                      f"๐ API endpoints are now available!")
            summary_text = json.dumps(summary, indent=2, ensure_ascii=False)
            return status, summary_text, self.format_training_history()
        except Exception as e:
            logger.error(f"Training error: {e}")
            return f"โ Training failed: {str(e)}", "", ""

    def load_model(self, model_path: str) -> tuple:
        """Load a pre-trained model from *model_path*.

        NOTE(review): the path comes straight from the UI; if
        ``InvoiceRAGSystem.load_model`` unpickles it, only load trusted files.

        Returns:
            ``(status_text, summary_json, history_markdown)``.
        """
        if not model_path or not os.path.exists(model_path):
            return "โ Invalid model path", "", ""
        try:
            self.rag_system.load_model(model_path)
            self.is_trained = True
            summary = self.rag_system.get_pattern_summary()
            status = (f"โ Model loaded successfully!\n"
                      f"๐ Loaded {summary['total_invoices']} invoices\n"
                      f"๐ {summary['total_chunks']} chunks available\n"
                      f"๐ API endpoints are now available!")
            summary_text = json.dumps(summary, indent=2, ensure_ascii=False)
            return status, summary_text, self.format_training_history()
        except Exception as e:
            logger.error(f"Model loading error: {e}")
            return f"โ Failed to load model: {str(e)}", "", ""

    def save_model(self, save_path: str) -> str:
        """Save the current model to *save_path* (``.pkl`` is appended if missing)."""
        if not self.is_trained:
            return "โ No trained model to save"
        if not save_path:
            return "โ Please provide a save path"
        try:
            if not save_path.endswith('.pkl'):
                save_path += '.pkl'
            self.rag_system.save_model(save_path)
            return f"โ Model saved to {save_path}"
        except Exception as e:
            logger.error(f"Model saving error: {e}")
            return f"โ Failed to save model: {str(e)}"

    def query_invoices(self, query: str, provider: str, model: str,
                       context_sections: List[str], top_k: int,
                       temperature: float, max_tokens: int) -> tuple:
        """Answer *query* with the selected LLM, grounded in retrieved invoice chunks.

        Returns:
            ``(llm_answer, rag_context_markdown, extracted_patterns_json)``.
        """
        if not self.is_trained:
            return "โ RAG system not trained. Please train or load a model first.", "", ""
        if not query or not query.strip():
            return "โ Please enter a query", "", ""
        if not provider or provider not in self.llm_manager.get_available_providers():
            return "โ Please select a valid LLM provider", "", ""
        try:
            # Slider values may arrive as floats (e.g. via the API); slicing needs an int.
            top_k = int(top_k)
            max_tokens = int(max_tokens)

            rag_results = self.rag_system.extract_invoice_info(
                query,
                context_sections if context_sections else None
            )

            context_chunks = rag_results['context_chunks'][:top_k]
            context_text = "\n\n".join(
                f"[{chunk['type']}] From {chunk['source']}:\n{chunk['content']}"
                for chunk in context_chunks
            )

            prompt = f"""Based on the following invoice data, please answer the user's question.
Context from invoices:
{context_text}
Extracted patterns:
{json.dumps(rag_results['extracted_patterns'], indent=2)}
User question: {query}
Please provide a detailed and accurate answer based on the invoice data provided. If you cannot find specific information in the context, please mention that."""

            llm_response = self.llm_manager.generate_response(
                provider, model, prompt, max_tokens, temperature
            )

            # Sorted so the section list is stable between identical queries.
            section_names = ', '.join(sorted({chunk['type'] for chunk in context_chunks}))
            rag_info = f"""**RAG Context Retrieved:**
- Sources: {rag_results['num_sources']} invoices
- Chunks: {len(context_chunks)} relevant sections
- Sections: {section_names}
**Top Retrieved Chunks:**
"""
            for i, chunk in enumerate(context_chunks[:3], 1):
                rag_info += f"\n{i}. [{chunk['type']}] {chunk['source']} (Score: {chunk['score']:.3f})\n"
                rag_info += f"   {chunk['content'][:200]}{'...' if len(chunk['content']) > 200 else ''}\n"

            return llm_response, rag_info, json.dumps(rag_results['extracted_patterns'], indent=2)
        except Exception as e:
            logger.error(f"Query error: {e}")
            return f"โ Query failed: {str(e)}", "", ""

    def format_training_history(self) -> str:
        """Render the training history as Markdown, newest run first."""
        if not self.training_history:
            return "No training history available"
        history = "**Training History:**\n\n"
        for i, entry in enumerate(reversed(self.training_history), 1):
            history += f"{i}. **{entry['timestamp'][:19]}**\n"
            history += f"   ๐ง Method: {entry['method'].replace('_', ' ').title()}\n"
            if 'folder' in entry:
                history += f"   ๐ Folder: {entry['folder']}\n"
            history += f"   ๐ {entry['num_invoices']} invoices, {entry['num_chunks']} chunks\n\n"
        return history

    def get_system_status(self) -> str:
        """Render RAG, API and LLM-provider status as Markdown."""
        available_providers = self.llm_manager.get_available_providers()
        endpoints_text = (f"{len(self.API_ENDPOINTS)} endpoints ready"
                          if self.is_trained else 'Training required')
        status = f"""**System Status:**
**RAG System:**
- Trained: {'โ Yes' if self.is_trained else 'โ No'}
- Chunks: {len(self.rag_system.chunks) if self.is_trained else 0}
- Index: {'โ Built' if self.rag_system.index is not None else 'โ Not built'}
**Gradio API:**
- Status: {'โ Active' if self.is_trained else 'โณ Waiting for training'}
- Available Endpoints: {endpoints_text}
**Available LLM Providers:**
"""
        for provider in available_providers:
            models = self.llm_manager.get_models_for_provider(provider)
            status += f"- **{provider.upper()}**: {', '.join(models)}\n"
        if not available_providers:
            status += "โ No LLM providers configured. Please set API keys.\n"
        return status

    def get_api_info(self) -> str:
        """Describe the registered Gradio API endpoints and how to call them (Markdown).

        Returns a short notice instead when the RAG system is not trained yet.
        """
        if not self.is_trained:
            return "โ API endpoints not available - RAG system not trained"

        endpoint_lines = "\n".join(
            f"- `/{name}` - {description}" for name, description in self.API_ENDPOINTS
        )
        # Plain (non f-) string so the JSON braces need no escaping.
        usage = """**Usage Examples:**

**Python (`gradio_client`):**
```python
from gradio_client import Client

client = Client("http://localhost:7860/")

# Query invoices
result = client.predict(
    "What are all invoice numbers?",  # query
    "header,totals",                  # context_sections
    api_name="/api_query_invoice_info",
)

# Get summary
summary = client.predict(api_name="/api_get_invoice_summary")
```

**cURL:**
```bash
# 1) Submit the request; the response contains an EVENT_ID
curl -X POST "http://localhost:7860/call/api_query_invoice_info" \\
  -H "Content-Type: application/json" \\
  -d '{"data": ["Extract vendor information", "vendor"]}'

# 2) Fetch (stream) the result for that EVENT_ID
curl -N "http://localhost:7860/call/api_query_invoice_info/$EVENT_ID"
```

**Base URL:** `http://localhost:7860`
**API Documentation:** "Use via API" link in the app footer (`http://localhost:7860/?view=api`)
"""
        return (
            "**Gradio API Information:**\n"
            "**Available Endpoints:**\n"
            f"{endpoint_lines}\n"
            "**API Status:** โ Active\n"
            f"**Endpoint Count:** {len(self.API_ENDPOINTS)}\n"
            + usage
        )

    def create_interface(self):
        """Build the Gradio Blocks app (Training, Query, API Tools, Status tabs).

        The ``api_*`` methods are attached to buttons with explicit ``api_name``
        values matching ``API_ENDPOINTS``, so they are also callable through
        Gradio's API.

        Returns:
            The ``gr.Blocks`` app; call ``launch()`` on it.
        """
        providers = self.llm_manager.get_available_providers()
        default_provider = providers[0] if providers else None
        default_models = (self.llm_manager.get_models_for_provider(default_provider)
                          if default_provider else [])

        with gr.Blocks(title="Invoice RAG System with API", theme=gr.themes.Soft()) as demo:
            gr.Markdown("# ๐ Invoice RAG System with Gradio API")
            gr.Markdown("Train on invoice PDFs and query them using different language models or API endpoints")

            with gr.Tabs():
                # Training Tab
                with gr.TabItem("๐ฏ Training"):
                    gr.Markdown("## Train RAG Model")
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("### ๐ค Upload Invoice PDFs")
                            upload_files = gr.File(
                                label="Upload Invoice PDFs",
                                file_count="multiple",
                                file_types=[".pdf"],
                                height=200,
                            )
                            upload_train_btn = gr.Button("๐ Upload & Train", variant="primary")
                        with gr.Column():
                            gr.Markdown("### ๐ Train from Folder")
                            invoice_folder = gr.Textbox(
                                label="Invoice Folder Path",
                                placeholder="Path to folder containing PDF invoices",
                            )
                            folder_train_btn = gr.Button("๐ Train from Folder", variant="secondary")
                    training_status = gr.Textbox(
                        label="Training Status",
                        interactive=False,
                        lines=4,
                    )
                    with gr.Row():
                        with gr.Column():
                            summary_output = gr.Code(
                                label="Pattern Summary",
                                language="json",
                                lines=10,
                            )
                        with gr.Column():
                            history_output = gr.Markdown(label="Training History")
                    gr.Markdown("### ๐พ Save/Load Model")
                    with gr.Row():
                        with gr.Column():
                            save_path = gr.Textbox(
                                label="Save Path",
                                placeholder="model_name.pkl",
                            )
                            save_btn = gr.Button("๐พ Save Model")
                            save_status = gr.Textbox(label="Save Status", interactive=False)
                        with gr.Column():
                            model_path = gr.Textbox(
                                label="Model Path",
                                placeholder="Path to saved model (.pkl)",
                            )
                            load_btn = gr.Button("๐ฅ Load Model")

                # Query Tab
                with gr.TabItem("๐ Query"):
                    gr.Markdown("## Query Invoice Data")
                    with gr.Row():
                        with gr.Column(scale=2):
                            query_input = gr.Textbox(
                                label="Your Question",
                                placeholder="What are the invoice numbers?",
                                lines=2,
                            )
                            provider_dropdown = gr.Dropdown(
                                choices=providers,
                                label="LLM Provider",
                                value=default_provider,
                            )
                            # Preselect a model so a query never runs with model=None.
                            model_dropdown = gr.Dropdown(
                                label="Model",
                                choices=default_models,
                                value=default_models[0] if default_models else None,
                            )
                        with gr.Column(scale=1):
                            context_sections = gr.CheckboxGroup(
                                choices=["header", "vendor", "client", "items", "totals", "footer"],
                                label="Context Sections",
                                info="Leave empty for all sections",
                            )
                            top_k = gr.Slider(
                                minimum=1, maximum=20, value=5, step=1,
                                label="Top K Results",
                            )
                            temperature = gr.Slider(
                                minimum=0.0, maximum=2.0, value=0.7, step=0.1,
                                label="Temperature",
                            )
                            max_tokens = gr.Slider(
                                minimum=100, maximum=8192, value=4096, step=100,
                                label="Max Tokens",
                            )
                    query_btn = gr.Button("๐ค Query RAG System", variant="primary")
                    with gr.Row():
                        with gr.Column():
                            llm_response = gr.Textbox(
                                label="LLM Response",
                                lines=10,
                                interactive=False,
                            )
                        with gr.Column():
                            rag_context = gr.Markdown(label="RAG Context")
                    patterns_output = gr.Code(
                        label="Extracted Patterns",
                        language="json",
                        lines=5,
                    )

                # API Tools Tab
                with gr.TabItem("๐ง API Tools"):
                    gr.Markdown("## Test API Functions Directly")
                    gr.Markdown("These functions are exposed via Gradio's built-in API system")
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("### Query Invoice Info")
                            api_query = gr.Textbox(
                                label="Query",
                                placeholder="What are all the invoice numbers?",
                            )
                            api_sections = gr.Textbox(
                                label="Context Sections (comma-separated)",
                                placeholder="header,vendor,totals",
                                info="Optional: specify which sections to focus on",
                            )
                            api_query_btn = gr.Button("๐ Run API Query")
                            api_query_output = gr.Code(language="json", lines=8)
                        with gr.Column():
                            gr.Markdown("### Extract Specific Field")
                            field_name = gr.Textbox(
                                label="Field Name",
                                placeholder="invoice_number, total, vendor_name",
                            )
                            invoice_source = gr.Textbox(
                                label="Invoice Source (optional)",
                                placeholder="Leave empty to search all invoices",
                            )
                            extract_btn = gr.Button("๐ Extract Field")
                            extract_output = gr.Code(language="json", lines=8)
                    with gr.Row():
                        with gr.Column():
                            summary_btn = gr.Button("๐ Get Invoice Summary")
                            summary_api_output = gr.Code(language="json", lines=6)
                        with gr.Column():
                            list_btn = gr.Button("๐ List Available Invoices")
                            list_output = gr.Code(language="json", lines=6)
                    # Registers api_upload_and_train, which get_api_info advertises.
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("### Upload & Train")
                            api_upload_files = gr.File(
                                label="Invoice PDFs",
                                file_count="multiple",
                                file_types=[".pdf"],
                                type="filepath",
                            )
                            api_upload_btn = gr.Button("Upload & Train via API")
                            api_upload_output = gr.Code(language="json", lines=8)

                # Status Tab
                with gr.TabItem("๐ Status & API"):
                    gr.Markdown("## System Status & API Information")
                    with gr.Row():
                        status_btn = gr.Button("๐ Refresh Status")
                        mcp_info_btn = gr.Button("๐ Get API Info")
                    with gr.Row():
                        with gr.Column():
                            status_output = gr.Markdown()
                        with gr.Column():
                            mcp_info_output = gr.Markdown()

            # Predefined queries
            gr.Markdown("## ๐ Example Queries")
            gr.Examples(
                examples=[
                    ["What are all the invoice numbers?"],
                    ["Show me vendor information"],
                    ["Extract total amounts from all invoices"],
                    ["Find products with quantities and prices"],
                    ["What are the invoice dates?"],
                    ["List all companies mentioned in the invoices"],
                    ["What payment terms are mentioned?"],
                    ["Extract line items with descriptions and amounts"],
                ],
                inputs=[query_input],
                label="Click to use example queries",
            )

            # Event handlers
            def update_models(provider):
                # Refresh the model list and preselect its first entry.
                models = self.llm_manager.get_models_for_provider(provider) if provider else []
                return gr.Dropdown(choices=models, value=models[0] if models else None)

            provider_dropdown.change(
                update_models,
                inputs=[provider_dropdown],
                outputs=[model_dropdown],
            )
            upload_train_btn.click(
                self.upload_and_train_files,
                inputs=[upload_files],
                outputs=[training_status, summary_output, history_output],
            )
            folder_train_btn.click(
                self.train_rag_system,
                inputs=[invoice_folder],
                outputs=[training_status, summary_output, history_output],
            )
            load_btn.click(
                self.load_model,
                inputs=[model_path],
                outputs=[training_status, summary_output, history_output],
            )
            save_btn.click(
                self.save_model,
                inputs=[save_path],
                outputs=[save_status],
            )
            query_btn.click(
                self.query_invoices,
                inputs=[
                    query_input, provider_dropdown, model_dropdown,
                    context_sections, top_k, temperature, max_tokens,
                ],
                outputs=[llm_response, rag_context, patterns_output],
            )

            # API tool handlers: api_name values must match API_ENDPOINTS.
            api_query_btn.click(
                self.api_query_invoice_info,
                inputs=[api_query, api_sections],
                outputs=[api_query_output],
                api_name="api_query_invoice_info",
            )
            extract_btn.click(
                self.api_extract_specific_field,
                inputs=[field_name, invoice_source],
                outputs=[extract_output],
                api_name="api_extract_specific_field",
            )
            summary_btn.click(
                self.api_get_invoice_summary,
                outputs=[summary_api_output],
                api_name="api_get_invoice_summary",
            )
            list_btn.click(
                self.api_list_available_invoices,
                outputs=[list_output],
                api_name="api_list_available_invoices",
            )
            api_upload_btn.click(
                self.api_upload_and_train,
                inputs=[api_upload_files],
                outputs=[api_upload_output],
                api_name="api_upload_and_train",
            )
            status_btn.click(
                self.get_system_status,
                outputs=[status_output],
            )
            mcp_info_btn.click(
                self.get_api_info,
                outputs=[mcp_info_output],
            )

            # Populate the status panels when the page loads.
            demo.load(
                lambda: (self.get_system_status(), self.get_api_info()),
                outputs=[status_output, mcp_info_output],
            )

        return demo
def main():
    """Start the Invoice RAG app (entry point, tuned for Hugging Face Spaces).

    If a plain key variable is unset, an ``HF_``-prefixed secret is copied into
    it. Prints a warning when no LLM API key is found, then builds the
    interface and launches the Gradio server on 0.0.0.0:7860.
    """
    setup_environment()

    # Plain environment variable name -> display name of the API it enables.
    api_key_names = {
        "GROQ_API_KEY": "Groq API",
    }

    configured = []
    for env_var, display_name in api_key_names.items():
        plain_value = os.getenv(env_var)
        hf_value = os.getenv(f"HF_{env_var}")
        if not (plain_value or hf_value):
            continue
        configured.append(display_name)
        # Promote the HF_ secret only when the plain variable is missing or empty;
        # LLMManager reads the plain name.
        if hf_value and not plain_value:
            os.environ[env_var] = hf_value

    if not configured:
        print("โ ๏ธ Warning: No API keys found.")
        print("Set GROQ_API_KEY in HF Spaces secrets or environment")

    demo = InvoiceRAGInterface().create_interface()
    print("๐ Starting Invoice RAG System on Hugging Face Spaces...")

    # NOTE(review): share=True opens a public share tunnel when run outside
    # Spaces. Together with auth=None, anyone with the link can use the
    # folder-path, save-model and load-model inputs. Consider share=False or auth.
    launch_options = dict(
        server_name="0.0.0.0",  # listen on all interfaces (required in containers)
        server_port=7860,       # Gradio's default port
        share=True,
        debug=False,
        auth=None,
        show_api=True,          # expose the "Use via API" page
        max_threads=40,         # cap on concurrent worker threads
    )
    demo.launch(**launch_options)


if __name__ == "__main__":
    main()