YashArya16 commited on
Commit
9ec2c50
Β·
verified Β·
1 Parent(s): 397b599

Upload invoice_rag_gradio_api.py

Browse files
Files changed (1) hide show
  1. invoice_rag_gradio_api.py +940 -0
invoice_rag_gradio_api.py ADDED
@@ -0,0 +1,940 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import os
4
+ import asyncio
5
+ import shutil
6
+ import tempfile
7
+ from typing import Dict, Any, Optional, List, Union
8
+ import logging
9
+ from datetime import datetime
10
+
11
+ # LLM integrations
12
+ import groq
13
+
14
+ # Import the RAG system
15
+ from invoice_rag_system import InvoiceRAGSystem
16
+
17
+ # Setup logging
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger("invoice-rag-gradio")
20
+
21
+ def setup_environment():
22
+ """Setup environment for HF Spaces"""
23
+ # Set default paths for HF Spaces
24
+ if not os.path.exists("sample_invoices"):
25
+ os.makedirs("sample_invoices")
26
+
27
+ # Check for HF Spaces environment
28
+ if os.getenv("SPACE_ID"):
29
+ print(f"πŸš€ Running on Hugging Face Spaces: {os.getenv('SPACE_ID')}")
30
+
31
+ return True
32
+
33
+ class LLMManager:
34
+ """Manage different LLM providers"""
35
+
36
+ def __init__(self):
37
+ self.providers = {
38
+ "groq": {
39
+ "client": None,
40
+ "models": ["llama-3.3-70b-versatile", "mixtral-8x7b-32768", "llama-3.1-8b-instant"],
41
+ "api_key_env": "GROQ_API_KEY"
42
+ },
43
+ }
44
+ self.initialize_clients()
45
+
46
+ def initialize_clients(self):
47
+ """Initialize LLM clients based on available API keys"""
48
+ # Groq
49
+ if os.getenv(self.providers["groq"]["api_key_env"]):
50
+ try:
51
+ self.providers["groq"]["client"] = groq.Client(
52
+ api_key=os.getenv(self.providers["groq"]["api_key_env"])
53
+ )
54
+ logger.info("Groq client initialized")
55
+ except Exception as e:
56
+ logger.error(f"Failed to initialize Groq client: {e}")
57
+
58
+ def get_available_providers(self) -> List[str]:
59
+ """Get list of available providers"""
60
+ return [provider for provider, config in self.providers.items()
61
+ if config["client"] is not None]
62
+
63
+ def get_models_for_provider(self, provider: str) -> List[str]:
64
+ """Get available models for a provider"""
65
+ if provider in self.providers and self.providers[provider]["client"]:
66
+ return self.providers[provider]["models"]
67
+ return []
68
+
69
+ def generate_response(self, provider: str, model: str, prompt: str,
70
+ max_tokens: int = 4096, temperature: float = 0.7) -> str:
71
+ """Generate response using specified provider and model"""
72
+ try:
73
+ if provider == "groq":
74
+ response = self.providers[provider]["client"].chat.completions.create(
75
+ messages=[{"role": "user", "content": prompt}],
76
+ model=model,
77
+ max_tokens=max_tokens,
78
+ temperature=temperature,
79
+ )
80
+ return response.choices[0].message.content.strip()
81
+ else:
82
+ return f"Error: Provider {provider} not supported or not initialized"
83
+
84
+ except Exception as e:
85
+ logger.error(f"Error generating response with {provider}/{model}: {e}")
86
+ return f"Error: {str(e)}"
87
+
88
+
89
+ class InvoiceRAGInterface:
90
+ """Gradio interface for Invoice RAG system with built-in API"""
91
+
92
+ def __init__(self):
93
+ setup_environment()
94
+ self.rag_system = InvoiceRAGSystem()
95
+ self.llm_manager = LLMManager()
96
+ self.is_trained = False
97
+ self.training_history = []
98
+ self.temp_upload_dir = tempfile.mkdtemp()
99
+
100
+ # API Functions (exposed via Gradio's built-in API)
101
+ def api_query_invoice_info(self, query: str, context_sections: str = None) -> str:
102
+ """Extract information from invoices using the RAG system.
103
+
104
+ Args:
105
+ query: The question to ask about the invoices
106
+ context_sections: Comma-separated list of sections to focus on (header,vendor,client,items,totals,footer)
107
+
108
+ Returns:
109
+ Extracted information and patterns from the invoice data
110
+ """
111
+ if not self.is_trained:
112
+ return json.dumps({"error": "RAG system not trained. Please train the system first with invoice PDFs."})
113
+
114
+ if not query.strip():
115
+ return json.dumps({"error": "Please provide a query"})
116
+
117
+ try:
118
+ # Parse context sections
119
+ sections = None
120
+ if context_sections:
121
+ sections = [s.strip() for s in context_sections.split(',') if s.strip()]
122
+
123
+ # Extract information using RAG
124
+ rag_results = self.rag_system.extract_invoice_info(query, sections)
125
+
126
+ # Format response
127
+ response = {
128
+ "success": True,
129
+ "query": query,
130
+ "sources_found": rag_results['num_sources'],
131
+ "chunks_retrieved": len(rag_results['context_chunks']),
132
+ "extracted_patterns": rag_results['extracted_patterns'],
133
+ "relevant_chunks": [
134
+ {
135
+ "source": chunk['source'],
136
+ "type": chunk['type'],
137
+ "content": chunk['content'][:500] + "..." if len(chunk['content']) > 500 else chunk['content'],
138
+ "relevance_score": chunk['score']
139
+ }
140
+ for chunk in rag_results['context_chunks'][:5]
141
+ ]
142
+ }
143
+
144
+ return json.dumps(response, indent=2, ensure_ascii=False)
145
+
146
+ except Exception as e:
147
+ logger.error(f"API Query error: {e}")
148
+ return json.dumps({"error": f"Query failed: {str(e)}"})
149
+
150
+ def api_get_invoice_summary(self) -> str:
151
+ """Get a summary of all processed invoices and their patterns."""
152
+ if not self.is_trained:
153
+ return json.dumps({"error": "RAG system not trained. Please train the system first with invoice PDFs."})
154
+
155
+ try:
156
+ summary = self.rag_system.get_pattern_summary()
157
+ return json.dumps({"success": True, "summary": summary}, indent=2, ensure_ascii=False)
158
+ except Exception as e:
159
+ return json.dumps({"error": f"Failed to get summary: {str(e)}"})
160
+
161
+ def api_extract_specific_field(self, field_name: str, invoice_source: str = None) -> str:
162
+ """Extract a specific field from invoices.
163
+
164
+ Args:
165
+ field_name: The field to extract (e.g., 'invoice_number', 'total', 'vendor_name')
166
+ invoice_source: Optional specific invoice to search in
167
+ """
168
+ if not self.is_trained:
169
+ return json.dumps({"error": "RAG system not trained. Please train the system first with invoice PDFs."})
170
+
171
+ try:
172
+ query = f"Find all {field_name} values"
173
+ if invoice_source:
174
+ query += f" from {invoice_source}"
175
+
176
+ rag_results = self.rag_system.extract_invoice_info(query)
177
+
178
+ # Extract the specific field from patterns
179
+ field_values = []
180
+ for pattern in rag_results['extracted_patterns']:
181
+ if field_name.lower() in str(pattern).lower():
182
+ field_values.append(pattern)
183
+
184
+ result = {
185
+ "success": True,
186
+ "field": field_name,
187
+ "values_found": len(field_values),
188
+ "values": field_values,
189
+ "source_invoices": rag_results['num_sources']
190
+ }
191
+
192
+ return json.dumps(result, indent=2, ensure_ascii=False)
193
+
194
+ except Exception as e:
195
+ return json.dumps({"error": f"Field extraction failed: {str(e)}"})
196
+
197
+ def api_list_available_invoices(self) -> str:
198
+ """List all available invoices in the RAG system."""
199
+ if not self.is_trained:
200
+ return json.dumps({"error": "RAG system not trained. Please train the system first with invoice PDFs."})
201
+
202
+ try:
203
+ # Get unique sources from chunks
204
+ sources = set()
205
+ chunk_counts = {}
206
+
207
+ for chunk in self.rag_system.chunks:
208
+ source = chunk.source_file
209
+ sources.add(source)
210
+ chunk_counts[source] = chunk_counts.get(source, 0) + 1
211
+
212
+ result = {
213
+ "success": True,
214
+ "total_invoices": len(sources),
215
+ "total_chunks": len(self.rag_system.chunks),
216
+ "invoices": [
217
+ {
218
+ "source": source,
219
+ "chunks": chunk_counts.get(source, 0)
220
+ }
221
+ for source in sorted(sources)
222
+ ]
223
+ }
224
+
225
+ return json.dumps(result, indent=2, ensure_ascii=False)
226
+
227
+ except Exception as e:
228
+ return json.dumps({"error": f"Failed to list invoices: {str(e)}"})
229
+
230
+ def api_upload_and_train(self, files: List[str]) -> str:
231
+ """Upload invoices and train the RAG system.
232
+
233
+ Args:
234
+ files: List of file paths to invoice PDFs
235
+ """
236
+ try:
237
+ if not files:
238
+ return json.dumps({"error": "No files provided"})
239
+
240
+ # Create a temporary directory for this training session
241
+ training_dir = tempfile.mkdtemp()
242
+
243
+ # Copy uploaded files to training directory
244
+ pdf_count = 0
245
+ for file_path in files:
246
+ if file_path and os.path.exists(file_path) and file_path.lower().endswith('.pdf'):
247
+ filename = os.path.basename(file_path)
248
+ shutil.copy2(file_path, os.path.join(training_dir, filename))
249
+ pdf_count += 1
250
+
251
+ if pdf_count == 0:
252
+ return json.dumps({"error": "No valid PDF files found"})
253
+
254
+ # Train the system
255
+ self.rag_system.train_on_invoices(training_dir)
256
+ self.is_trained = True
257
+
258
+ # Get summary
259
+ summary = self.rag_system.get_pattern_summary()
260
+
261
+ # Update training history
262
+ self.training_history.append({
263
+ 'timestamp': datetime.now().isoformat(),
264
+ 'method': 'upload_and_train',
265
+ 'num_invoices': summary['total_invoices'],
266
+ 'num_chunks': summary['total_chunks']
267
+ })
268
+
269
+ # Clean up temporary directory
270
+ shutil.rmtree(training_dir)
271
+
272
+ result = {
273
+ "success": True,
274
+ "message": f"Training completed successfully! Processed {pdf_count} PDF files.",
275
+ "invoices_processed": summary['total_invoices'],
276
+ "chunks_created": summary['total_chunks'],
277
+ "summary": summary
278
+ }
279
+
280
+ return json.dumps(result, indent=2, ensure_ascii=False)
281
+
282
+ except Exception as e:
283
+ logger.error(f"Upload and train error: {e}")
284
+ return json.dumps({"error": f"Training failed: {str(e)}"})
285
+
286
+ # Regular Interface Functions
287
+ def upload_and_train_files(self, files, progress=gr.Progress()) -> tuple:
288
+ """Handle file upload and training"""
289
+ if not files:
290
+ return "❌ No files uploaded", "", ""
291
+
292
+ try:
293
+ progress(0, desc="Processing uploaded files...")
294
+
295
+ # Filter PDF files
296
+ pdf_files = [f for f in files if f.name.lower().endswith('.pdf')]
297
+ if not pdf_files:
298
+ return "❌ No PDF files found in upload", "", ""
299
+
300
+ progress(0.2, desc=f"Found {len(pdf_files)} PDF files")
301
+
302
+ # Create temporary directory and copy files
303
+ training_dir = tempfile.mkdtemp()
304
+ for pdf_file in pdf_files:
305
+ filename = os.path.basename(pdf_file.name)
306
+ shutil.copy2(pdf_file.name, os.path.join(training_dir, filename))
307
+
308
+ progress(0.4, desc="Training RAG system...")
309
+
310
+ # Train the system
311
+ self.rag_system.train_on_invoices(training_dir)
312
+ progress(0.8, desc="Building index...")
313
+
314
+ self.is_trained = True
315
+
316
+ # Get summary
317
+ summary = self.rag_system.get_pattern_summary()
318
+ progress(1.0, desc="Training complete!")
319
+
320
+ # Update training history
321
+ self.training_history.append({
322
+ 'timestamp': datetime.now().isoformat(),
323
+ 'method': 'file_upload',
324
+ 'num_invoices': summary['total_invoices'],
325
+ 'num_chunks': summary['total_chunks']
326
+ })
327
+
328
+ # Clean up
329
+ shutil.rmtree(training_dir)
330
+
331
+ status = f"βœ… Training completed successfully!\n" \
332
+ f"πŸ“ Processed {len(pdf_files)} PDF files\n" \
333
+ f"πŸ“„ Created {summary['total_chunks']} chunks\n" \
334
+ f"πŸš€ API endpoints are now available!"
335
+
336
+ summary_text = json.dumps(summary, indent=2, ensure_ascii=False)
337
+
338
+ return status, summary_text, self.format_training_history()
339
+
340
+ except Exception as e:
341
+ logger.error(f"Upload training error: {e}")
342
+ return f"❌ Training failed: {str(e)}", "", ""
343
+
344
+ def train_rag_system(self, invoice_folder: str, progress=gr.Progress()) -> tuple:
345
+ """Train the RAG system on invoice folder"""
346
+ if not invoice_folder or not os.path.exists(invoice_folder):
347
+ return "❌ Invalid folder path", "", ""
348
+
349
+ try:
350
+ progress(0, desc="Starting training...")
351
+
352
+ # Count PDF files
353
+ pdf_files = [f for f in os.listdir(invoice_folder) if f.endswith('.pdf')]
354
+ if not pdf_files:
355
+ return "❌ No PDF files found in folder", "", ""
356
+
357
+ progress(0.2, desc=f"Found {len(pdf_files)} PDF files")
358
+
359
+ # Train the system
360
+ self.rag_system.train_on_invoices(invoice_folder)
361
+ progress(0.8, desc="Building index...")
362
+
363
+ self.is_trained = True
364
+
365
+ # Get summary
366
+ summary = self.rag_system.get_pattern_summary()
367
+ progress(1.0, desc="Training complete!")
368
+
369
+ # Update training history
370
+ self.training_history.append({
371
+ 'timestamp': datetime.now().isoformat(),
372
+ 'method': 'folder_training',
373
+ 'folder': invoice_folder,
374
+ 'num_invoices': summary['total_invoices'],
375
+ 'num_chunks': summary['total_chunks']
376
+ })
377
+
378
+ status = f"βœ… Training completed successfully!\n" \
379
+ f"πŸ“ Processed {summary['total_invoices']} invoices\n" \
380
+ f"οΏ½οΏ½οΏ½ Created {summary['total_chunks']} chunks\n" \
381
+ f"πŸš€ API endpoints are now available!"
382
+
383
+ summary_text = json.dumps(summary, indent=2, ensure_ascii=False)
384
+
385
+ return status, summary_text, self.format_training_history()
386
+
387
+ except Exception as e:
388
+ logger.error(f"Training error: {e}")
389
+ return f"❌ Training failed: {str(e)}", "", ""
390
+
391
+ def load_model(self, model_path: str) -> tuple:
392
+ """Load a pre-trained model"""
393
+ if not model_path or not os.path.exists(model_path):
394
+ return "❌ Invalid model path", "", ""
395
+
396
+ try:
397
+ self.rag_system.load_model(model_path)
398
+ self.is_trained = True
399
+
400
+ summary = self.rag_system.get_pattern_summary()
401
+
402
+ status = f"βœ… Model loaded successfully!\n" \
403
+ f"πŸ“ Loaded {summary['total_invoices']} invoices\n" \
404
+ f"πŸ“„ {summary['total_chunks']} chunks available\n" \
405
+ f"πŸš€ API endpoints are now available!"
406
+
407
+ summary_text = json.dumps(summary, indent=2, ensure_ascii=False)
408
+
409
+ return status, summary_text, self.format_training_history()
410
+
411
+ except Exception as e:
412
+ logger.error(f"Model loading error: {e}")
413
+ return f"❌ Failed to load model: {str(e)}", "", ""
414
+
415
+ def save_model(self, save_path: str) -> str:
416
+ """Save the current model"""
417
+ if not self.is_trained:
418
+ return "❌ No trained model to save"
419
+
420
+ if not save_path:
421
+ return "❌ Please provide a save path"
422
+
423
+ try:
424
+ # Ensure .pkl extension
425
+ if not save_path.endswith('.pkl'):
426
+ save_path += '.pkl'
427
+
428
+ self.rag_system.save_model(save_path)
429
+ return f"βœ… Model saved to {save_path}"
430
+
431
+ except Exception as e:
432
+ logger.error(f"Model saving error: {e}")
433
+ return f"❌ Failed to save model: {str(e)}"
434
+
435
+ def query_invoices(self, query: str, provider: str, model: str,
436
+ context_sections: List[str], top_k: int,
437
+ temperature: float, max_tokens: int) -> tuple:
438
+ """Query the invoice RAG system"""
439
+ if not self.is_trained:
440
+ return "❌ RAG system not trained. Please train or load a model first.", "", ""
441
+
442
+ if not query.strip():
443
+ return "❌ Please enter a query", "", ""
444
+
445
+ if not provider or provider not in self.llm_manager.get_available_providers():
446
+ return "❌ Please select a valid LLM provider", "", ""
447
+
448
+ try:
449
+ # Extract information using RAG
450
+ rag_results = self.rag_system.extract_invoice_info(
451
+ query,
452
+ context_sections if context_sections else None
453
+ )
454
+
455
+ # Prepare context for LLM
456
+ context_chunks = rag_results['context_chunks'][:top_k]
457
+ context_text = "\n\n".join(
458
+ f"[{chunk['type']}] From {chunk['source']}:\n{chunk['content']}"
459
+ for chunk in context_chunks
460
+ )
461
+
462
+ # Create prompt for LLM
463
+ prompt = f"""Based on the following invoice data, please answer the user's question.
464
+
465
+ Context from invoices:
466
+ {context_text}
467
+
468
+ Extracted patterns:
469
+ {json.dumps(rag_results['extracted_patterns'], indent=2)}
470
+
471
+ User question: {query}
472
+
473
+ Please provide a detailed and accurate answer based on the invoice data provided. If you cannot find specific information in the context, please mention that."""
474
+
475
+ # Generate response using selected LLM
476
+ llm_response = self.llm_manager.generate_response(
477
+ provider, model, prompt, max_tokens, temperature
478
+ )
479
+
480
+ # Format RAG context info
481
+ rag_info = f"""**RAG Context Retrieved:**
482
+ - Sources: {rag_results['num_sources']} invoices
483
+ - Chunks: {len(context_chunks)} relevant sections
484
+ - Sections: {', '.join(set(chunk['type'] for chunk in context_chunks))}
485
+
486
+ **Top Retrieved Chunks:**
487
+ """
488
+
489
+ for i, chunk in enumerate(context_chunks[:3], 1):
490
+ rag_info += f"\n{i}. [{chunk['type']}] {chunk['source']} (Score: {chunk['score']:.3f})\n"
491
+ rag_info += f" {chunk['content'][:200]}{'...' if len(chunk['content']) > 200 else ''}\n"
492
+
493
+ return llm_response, rag_info, json.dumps(rag_results['extracted_patterns'], indent=2)
494
+
495
+ except Exception as e:
496
+ logger.error(f"Query error: {e}")
497
+ return f"❌ Query failed: {str(e)}", "", ""
498
+
499
+ def format_training_history(self) -> str:
500
+ """Format training history for display"""
501
+ if not self.training_history:
502
+ return "No training history available"
503
+
504
+ history = "**Training History:**\n\n"
505
+ for i, entry in enumerate(reversed(self.training_history), 1):
506
+ history += f"{i}. **{entry['timestamp'][:19]}**\n"
507
+ history += f" πŸ”§ Method: {entry['method'].replace('_', ' ').title()}\n"
508
+ if 'folder' in entry:
509
+ history += f" πŸ“ Folder: {entry['folder']}\n"
510
+ history += f" πŸ“Š {entry['num_invoices']} invoices, {entry['num_chunks']} chunks\n\n"
511
+
512
+ return history
513
+
514
+ def get_system_status(self) -> str:
515
+ """Get current system status"""
516
+ available_providers = self.llm_manager.get_available_providers()
517
+
518
+ status = f"""**System Status:**
519
+
520
+ **RAG System:**
521
+ - Trained: {'βœ… Yes' if self.is_trained else '❌ No'}
522
+ - Chunks: {len(self.rag_system.chunks) if self.is_trained else 0}
523
+ - Index: {'βœ… Built' if self.rag_system.index is not None else '❌ Not built'}
524
+
525
+ **Gradio API:**
526
+ - Status: {'βœ… Active' if self.is_trained else '⏳ Waiting for training'}
527
+ - Available Endpoints: {'4 endpoints ready' if self.is_trained else 'Training required'}
528
+
529
+ **Available LLM Providers:**
530
+ """
531
+
532
+ for provider in available_providers:
533
+ models = self.llm_manager.get_models_for_provider(provider)
534
+ status += f"- **{provider.upper()}**: {', '.join(models)}\n"
535
+
536
+ if not available_providers:
537
+ status += "❌ No LLM providers configured. Please set API keys.\n"
538
+
539
+ return status
540
+
541
+ def get_api_info(self) -> str:
542
+ """Get API endpoint information"""
543
+ if not self.is_trained:
544
+ return "❌ API endpoints not available - RAG system not trained"
545
+
546
+ api_endpoints = [
547
+ "πŸ” `/api/query_invoice_info` - Extract information from invoices",
548
+ "πŸ“‹ `/api/get_invoice_summary` - Get summary of all processed invoices",
549
+ "πŸ”Ž `/api/extract_specific_field` - Extract specific fields from invoices",
550
+ "πŸ“„ `/api/list_available_invoices` - List all available invoice sources",
551
+ "πŸ“€ `/api/upload_and_train` - Upload and train on new invoices"
552
+ ]
553
+
554
+ info = f"""**Gradio API Information:**
555
+
556
+ **Available Endpoints:**
557
+ {chr(10).join(api_endpoints)}
558
+
559
+ **API Status:** βœ… Active
560
+ **Endpoint Count:** {len(api_endpoints)}
561
+
562
+ **Usage Examples:**
563
+
564
+ **Python:**
565
+ ```python
566
+ import requests
567
+
568
+ # Query invoices
569
+ response = requests.post("http://localhost:7860/api/predict", json={{
570
+ "data": ["What are all invoice numbers?", "header,totals"],
571
+ "fn_index": 0 # api_query_invoice_info function index
572
+ }})
573
+
574
+ # Get summary
575
+ response = requests.post("http://localhost:7860/api/predict", json={{
576
+ "data": [],
577
+ "fn_index": 1 # api_get_invoice_summary function index
578
+ }})
579
+ ```
580
+
581
+ **cURL:**
582
+ ```bash
583
+ # Query invoices
584
+ curl -X POST "http://localhost:7860/api/predict" \\
585
+ -H "Content-Type: application/json" \\
586
+ -d '{{"data": ["Extract vendor information", "vendor"], "fn_index": 0}}'
587
+
588
+ # Get invoice summary
589
+ curl -X POST "http://localhost:7860/api/predict" \\
590
+ -H "Content-Type: application/json" \\
591
+ -d '{{"data": [], "fn_index": 1}}'
592
+ ```
593
+
594
+ **Base URL:** `http://localhost:7860`
595
+ **API Documentation:** Available at `http://localhost:7860/docs`
596
+ """
597
+
598
+ return info
599
+
600
+ def create_interface(self):
601
+ """Create the Gradio interface with built-in API support"""
602
+
603
+ with gr.Blocks(title="Invoice RAG System with API", theme=gr.themes.Soft()) as demo:
604
+
605
+ gr.Markdown("# πŸ“„ Invoice RAG System with Gradio API")
606
+ gr.Markdown("Train on invoice PDFs and query them using different language models or API endpoints")
607
+
608
+ with gr.Tabs():
609
+
610
+ # Training Tab
611
+ with gr.TabItem("🎯 Training"):
612
+ gr.Markdown("## Train RAG Model")
613
+
614
+ with gr.Row():
615
+ with gr.Column():
616
+ gr.Markdown("### πŸ“€ Upload Invoice PDFs")
617
+ upload_files = gr.File(
618
+ label="Upload Invoice PDFs",
619
+ file_count="multiple",
620
+ file_types=[".pdf"],
621
+ height=200
622
+ )
623
+ upload_train_btn = gr.Button("πŸš€ Upload & Train", variant="primary")
624
+
625
+ with gr.Column():
626
+ gr.Markdown("### πŸ“ Train from Folder")
627
+ invoice_folder = gr.Textbox(
628
+ label="Invoice Folder Path",
629
+ placeholder="Path to folder containing PDF invoices"
630
+ )
631
+ folder_train_btn = gr.Button("πŸš€ Train from Folder", variant="secondary")
632
+
633
+ training_status = gr.Textbox(
634
+ label="Training Status",
635
+ interactive=False,
636
+ lines=4
637
+ )
638
+
639
+ with gr.Row():
640
+ with gr.Column():
641
+ summary_output = gr.Code(
642
+ label="Pattern Summary",
643
+ language="json",
644
+ lines=10
645
+ )
646
+
647
+ with gr.Column():
648
+ history_output = gr.Markdown(
649
+ label="Training History"
650
+ )
651
+
652
+ gr.Markdown("### πŸ’Ύ Save/Load Model")
653
+ with gr.Row():
654
+ with gr.Column():
655
+ save_path = gr.Textbox(
656
+ label="Save Path",
657
+ placeholder="model_name.pkl"
658
+ )
659
+ save_btn = gr.Button("πŸ’Ύ Save Model")
660
+ save_status = gr.Textbox(
661
+ label="Save Status",
662
+ interactive=False
663
+ )
664
+
665
+ with gr.Column():
666
+ model_path = gr.Textbox(
667
+ label="Model Path",
668
+ placeholder="Path to saved model (.pkl)"
669
+ )
670
+ load_btn = gr.Button("πŸ“₯ Load Model")
671
+
672
+ # Query Tab
673
+ with gr.TabItem("πŸ” Query"):
674
+ gr.Markdown("## Query Invoice Data")
675
+
676
+ with gr.Row():
677
+ with gr.Column(scale=2):
678
+ query_input = gr.Textbox(
679
+ label="Your Question",
680
+ placeholder="What are the invoice numbers?",
681
+ lines=2
682
+ )
683
+
684
+ provider_dropdown = gr.Dropdown(
685
+ choices=self.llm_manager.get_available_providers(),
686
+ label="LLM Provider",
687
+ value=self.llm_manager.get_available_providers()[0] if self.llm_manager.get_available_providers() else None
688
+ )
689
+
690
+ model_dropdown = gr.Dropdown(
691
+ label="Model",
692
+ choices=self.llm_manager.get_models_for_provider(
693
+ self.llm_manager.get_available_providers()[0] if self.llm_manager.get_available_providers() else ""
694
+ ) if self.llm_manager.get_available_providers() else []
695
+ )
696
+
697
+ with gr.Column(scale=1):
698
+ context_sections = gr.CheckboxGroup(
699
+ choices=["header", "vendor", "client", "items", "totals", "footer"],
700
+ label="Context Sections",
701
+ info="Leave empty for all sections"
702
+ )
703
+
704
+ top_k = gr.Slider(
705
+ minimum=1, maximum=20, value=5, step=1,
706
+ label="Top K Results"
707
+ )
708
+
709
+ temperature = gr.Slider(
710
+ minimum=0.0, maximum=2.0, value=0.7, step=0.1,
711
+ label="Temperature"
712
+ )
713
+
714
+ max_tokens = gr.Slider(
715
+ minimum=100, maximum=8192, value=4096, step=100,
716
+ label="Max Tokens"
717
+ )
718
+
719
+ query_btn = gr.Button("πŸ€– Query RAG System", variant="primary")
720
+
721
+ with gr.Row():
722
+ with gr.Column():
723
+ llm_response = gr.Textbox(
724
+ label="LLM Response",
725
+ lines=10,
726
+ interactive=False
727
+ )
728
+
729
+ with gr.Column():
730
+ rag_context = gr.Markdown(
731
+ label="RAG Context"
732
+ )
733
+
734
+ patterns_output = gr.Code(
735
+ label="Extracted Patterns",
736
+ language="json",
737
+ lines=5
738
+ )
739
+
740
+ # API Tools Tab
741
+ with gr.TabItem("πŸ”§ API Tools"):
742
+ gr.Markdown("## Test API Functions Directly")
743
+ gr.Markdown("These functions are exposed via Gradio's built-in API system")
744
+
745
+ with gr.Row():
746
+ with gr.Column():
747
+ gr.Markdown("### Query Invoice Info")
748
+ api_query = gr.Textbox(
749
+ label="Query",
750
+ placeholder="What are all the invoice numbers?"
751
+ )
752
+ api_sections = gr.Textbox(
753
+ label="Context Sections (comma-separated)",
754
+ placeholder="header,vendor,totals",
755
+ info="Optional: specify which sections to focus on"
756
+ )
757
+ api_query_btn = gr.Button("πŸ” Run API Query")
758
+ api_query_output = gr.Code(language="json", lines=8)
759
+
760
+ with gr.Column():
761
+ gr.Markdown("### Extract Specific Field")
762
+ field_name = gr.Textbox(
763
+ label="Field Name",
764
+ placeholder="invoice_number, total, vendor_name"
765
+ )
766
+ invoice_source = gr.Textbox(
767
+ label="Invoice Source (optional)",
768
+ placeholder="Leave empty to search all invoices"
769
+ )
770
+ extract_btn = gr.Button("πŸ”Ž Extract Field")
771
+ extract_output = gr.Code(language="json", lines=8)
772
+
773
+ with gr.Row():
774
+ with gr.Column():
775
+ summary_btn = gr.Button("πŸ“‹ Get Invoice Summary")
776
+ summary_api_output = gr.Code(language="json", lines=6)
777
+
778
+ with gr.Column():
779
+ list_btn = gr.Button("πŸ“„ List Available Invoices")
780
+ list_output = gr.Code(language="json", lines=6)
781
+
782
+ # Status Tab
783
+ with gr.TabItem("πŸ“Š Status & API"):
784
+ gr.Markdown("## System Status & API Information")
785
+
786
+ with gr.Row():
787
+ status_btn = gr.Button("πŸ”„ Refresh Status")
788
+ mcp_info_btn = gr.Button("πŸš€ Get API Info")
789
+
790
+ with gr.Row():
791
+ with gr.Column():
792
+ status_output = gr.Markdown()
793
+ with gr.Column():
794
+ mcp_info_output = gr.Markdown()
795
+
796
+ # Predefined queries
797
+ gr.Markdown("## πŸ“ Example Queries")
798
+ example_queries = gr.Examples(
799
+ examples=[
800
+ ["What are all the invoice numbers?"],
801
+ ["Show me vendor information"],
802
+ ["Extract total amounts from all invoices"],
803
+ ["Find products with quantities and prices"],
804
+ ["What are the invoice dates?"],
805
+ ["List all companies mentioned in the invoices"],
806
+ ["What payment terms are mentioned?"],
807
+ ["Extract line items with descriptions and amounts"]
808
+ ],
809
+ inputs=[query_input],
810
+ label="Click to use example queries"
811
+ )
812
+
813
+ # Event handlers
814
+ def update_models(provider):
815
+ if provider:
816
+ return gr.Dropdown(choices=self.llm_manager.get_models_for_provider(provider))
817
+ return gr.Dropdown(choices=[])
818
+
819
+ provider_dropdown.change(
820
+ update_models,
821
+ inputs=[provider_dropdown],
822
+ outputs=[model_dropdown]
823
+ )
824
+
825
+ upload_train_btn.click(
826
+ self.upload_and_train_files,
827
+ inputs=[upload_files],
828
+ outputs=[training_status, summary_output, history_output]
829
+ )
830
+
831
+ folder_train_btn.click(
832
+ self.train_rag_system,
833
+ inputs=[invoice_folder],
834
+ outputs=[training_status, summary_output, history_output]
835
+ )
836
+
837
+ load_btn.click(
838
+ self.load_model,
839
+ inputs=[model_path],
840
+ outputs=[training_status, summary_output, history_output]
841
+ )
842
+
843
+ save_btn.click(
844
+ self.save_model,
845
+ inputs=[save_path],
846
+ outputs=[save_status]
847
+ )
848
+
849
+ query_btn.click(
850
+ self.query_invoices,
851
+ inputs=[
852
+ query_input, provider_dropdown, model_dropdown,
853
+ context_sections, top_k, temperature, max_tokens
854
+ ],
855
+ outputs=[llm_response, rag_context, patterns_output]
856
+ )
857
+
858
+ # MCP Tool handlers
859
+ api_query_btn.click(
860
+ self.api_query_invoice_info,
861
+ inputs=[api_query, api_sections],
862
+ outputs=[api_query_output]
863
+ )
864
+
865
+ extract_btn.click(
866
+ self.api_extract_specific_field,
867
+ inputs=[field_name, invoice_source],
868
+ outputs=[extract_output]
869
+ )
870
+
871
+ summary_btn.click(
872
+ self.api_get_invoice_summary,
873
+ outputs=[summary_api_output]
874
+ )
875
+
876
+ list_btn.click(
877
+ self.api_get_invoice_summary,
878
+ outputs=[list_output]
879
+ )
880
+
881
+ status_btn.click(
882
+ self.get_system_status,
883
+ outputs=[status_output]
884
+ )
885
+
886
+ mcp_info_btn.click(
887
+ self.get_api_info,
888
+ outputs=[mcp_info_output]
889
+ )
890
+
891
+ # Initialize status on load
892
+ demo.load(
893
+ lambda: (self.get_system_status(), self.get_api_info()),
894
+ outputs=[status_output, mcp_info_output]
895
+ )
896
+
897
+ return demo
898
+
899
+ def main():
900
+ """Main function optimized for HF Spaces"""
901
+
902
+ # Setup
903
+ setup_environment()
904
+
905
+ # Check API keys with HF Spaces support
906
+ required_vars = {
907
+ "GROQ_API_KEY": "Groq API",
908
+ }
909
+
910
+ available_apis = []
911
+ for var, name in required_vars.items():
912
+ # Check both environment and HF Spaces secrets
913
+ if os.getenv(var) or os.getenv(f"HF_{var}"):
914
+ available_apis.append(name)
915
+ # Use HF secret if available
916
+ if os.getenv(f"HF_{var}") and not os.getenv(var):
917
+ os.environ[var] = os.getenv(f"HF_{var}")
918
+
919
+ if not available_apis:
920
+ print("⚠️ Warning: No API keys found.")
921
+ print("Set GROQ_API_KEY in HF Spaces secrets or environment")
922
+
923
+ # Create interface
924
+ interface = InvoiceRAGInterface()
925
+ demo = interface.create_interface()
926
+
927
+ print("πŸš€ Starting Invoice RAG System on Hugging Face Spaces...")
928
+
929
+ # HF Spaces optimized launch
930
+ demo.launch(
931
+ server_name="0.0.0.0",
932
+ server_port=7860,
933
+ share=False,
934
+ debug=False,
935
+ # Note: HF Spaces may not support all Gradio features
936
+ # Remove mcp_server=True if it causes issues
937
+ )
938
+
939
+ if __name__ == "__main__":
940
+ main()