Spaces:
Sleeping
Sleeping
| """ | |
| True End-to-End Agent Orchestrator | |
| =================================== | |
| Autonomous agent that: | |
| 1. Decides which tools to use based on document analysis | |
| 2. Validates its own output | |
| 3. Self-corrects when confidence is low | |
| 4. Learns from patterns | |
| """ | |
| import json | |
| import sys | |
| import logging | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| from dataclasses import dataclass | |
| from enum import Enum | |
# Logging setup: send INFO-and-above records to stdout with a timestamped
# format.  ``force=True`` asks basicConfig to reconfigure the root logger
# even if it was already configured before this module was imported.
logging.basicConfig(
    force=True,
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by everything below.
logger = logging.getLogger(__name__)
class AgentDecision(Enum):
    """Closed set of actions the agent can choose on each step.

    The string values are what gets recorded in ``AgentState.history``.
    """

    # Data-gathering steps
    EXTRACT_TEXT = "extract_text"
    EXTRACT_TABLES = "extract_tables"
    RUN_NER = "run_ner"

    # Field-mapping strategies
    USE_GEMINI = "use_gemini"
    USE_REGEX = "use_regex"

    # Quality control
    VALIDATE = "validate"
    RETRY = "retry"

    # Terminal decisions
    COMPLETE = "complete"
    HUMAN_REVIEW = "human_review"
@dataclass
class AgentState:
    """Agent's internal state for one document run.

    Declared as a dataclass so it can be built with keyword arguments
    (``AgentState(doc_id=..., file_path=...)``) and so ``__post_init__``
    runs after construction.

    Attributes:
        doc_id: Identifier of the document being processed.
        file_path: Location of the document on disk.
        raw_text / tables / entities / entity_map: Extraction outputs;
            ``None`` means "not produced yet".
        fields / confidence_map: Mapped invoice fields and the confidence
            recorded per field; ``None`` means "not mapped yet".
        attempts / max_attempts: Attempts used so far and the budget.
        history: Names of the actions taken, in order.
        errors: Error messages collected along the way.
    """

    doc_id: str
    file_path: Path

    # Extracted data
    raw_text: Optional[str] = None
    tables: Optional[List] = None
    entities: Optional[List] = None
    entity_map: Optional[Dict] = None

    # Mapped fields
    fields: Optional[Dict] = None
    confidence_map: Optional[Dict] = None

    # Decision tracking
    attempts: int = 0
    max_attempts: int = 3
    # ``None`` is only a placeholder default; __post_init__ replaces it with
    # a fresh list so instances never share one mutable object.
    history: Optional[List[str]] = None
    errors: Optional[List[str]] = None

    def __post_init__(self):
        """Replace the ``None`` placeholders with per-instance empty lists."""
        if self.history is None:
            self.history = []
        if self.errors is None:
            self.errors = []
class InvoiceAgent:
    """
    Autonomous agent that processes invoices with self-correction.

    On every step the agent picks the next action from the current state
    (extract text -> NER -> field mapping -> validation), executes it, and
    stops once it decides the result is complete or needs human review.

    Attempt accounting: ``state.attempts`` is charged once for every
    field-mapping pass (Gemini or regex) and once for every other action
    that fails, except a validation that merely rejects the fields.
    Routine successful steps are free, so the attempt budget limits
    retries rather than the number of pipeline stages.
    """

    # History entries that mark the start of a field-mapping pass.
    _MAPPING_ACTIONS = ('use_gemini', 'use_regex')

    def __init__(self, text_extractor, table_extractor, ner_extractor, gemini_mapper):
        """
        Args:
            text_extractor: Function(file_path) -> (success, text, error)
            table_extractor: Function(file_path) -> (success, tables, error)
            ner_extractor: Function(text) -> (success, entities, entity_map, error)
            gemini_mapper: Function(text, entities, entity_map, tables) -> (success, fields, error)
        """
        self.text_extractor = text_extractor
        self.table_extractor = table_extractor
        self.ner_extractor = ner_extractor
        self.gemini_mapper = gemini_mapper

        # Minimum confidence thresholds
        self.MIN_CONFIDENCE = {
            'cust_number': 0.6,
            'posting_date': 0.7,
            'total_open_amount': 0.7,
            'cust_payment_terms': 0.5
        }

    def process(self, state: "AgentState") -> "AgentState":
        """
        Main agent loop - autonomous decision-making and execution.

        Args:
            state: Working state for one document; it is mutated in place.

        Returns:
            The same ``state`` object once the loop has finished.
        """
        logger.info("=" * 70)
        logger.info(f"**** AGENT STARTING: {state.file_path.name}")
        logger.info("=" * 70)

        # Hard ceiling on iterations so a misbehaving tool can never spin
        # the loop forever; normal runs end on COMPLETE / HUMAN_REVIEW.
        step_limit = max(1, state.max_attempts) * len(AgentDecision)

        for step in range(1, step_limit + 1):
            # Step 1: Decide next action
            decision = self._decide_next_action(state)
            is_mapping = decision.value in self._MAPPING_ACTIONS
            if is_mapping:
                # Every mapping pass costs one attempt, whether it works or not.
                state.attempts += 1

            logger.info(f"\n**** STEP {step} (ATTEMPT {state.attempts}/{state.max_attempts})")
            logger.info(f"**** DECISION: {decision.value}")
            state.history.append(decision.value)

            # Step 2: Execute action
            success = self._execute_action(decision, state)
            if not success:
                logger.warning(f"**** Action {decision.value} failed")
                # A failed tool call costs an attempt.  A failed validation
                # is a verdict on the fields, not a tool failure, and a
                # mapping pass was already charged above.
                if not is_mapping and decision != AgentDecision.VALIDATE:
                    state.attempts += 1
                continue

            # Step 3: Check if we're done
            if decision == AgentDecision.COMPLETE:
                logger.info("**** AGENT COMPLETE")
                break
            if decision == AgentDecision.HUMAN_REVIEW:
                logger.info("**** AGENT REQUESTING HUMAN REVIEW")
                break
        else:
            logger.warning(f"**** Step limit ({step_limit}) reached without a final decision")

        logger.info("=" * 70)
        logger.info(f"**** Final confidence: {self._calculate_overall_confidence(state):.2f}")
        logger.info(f"**** Actions taken: {' → '.join(state.history)}")
        logger.info("=" * 70)
        return state

    def _decide_next_action(self, state: "AgentState") -> "AgentDecision":
        """
        Agent's brain - decides what to do next based on current state.

        Returns:
            The next ``AgentDecision``.  When the attempt budget is spent
            and text, entities or fields are still missing, the decision is
            ``HUMAN_REVIEW`` so the loop terminates.
        """
        out_of_attempts = state.attempts >= state.max_attempts
        prerequisite_missing = (
            state.raw_text is None
            or state.entities is None
            or state.fields is None
        )

        # 0. Budget spent and still nothing to validate: hand over
        if out_of_attempts and prerequisite_missing:
            return AgentDecision.HUMAN_REVIEW

        # 1. If no text, extract it
        if state.raw_text is None:
            return AgentDecision.EXTRACT_TEXT

        # 2. If text exists but no entities, run NER
        if state.entities is None:
            return AgentDecision.RUN_NER

        # 3. If no fields mapped yet, try Gemini first
        if state.fields is None:
            # The last recorded action being use_gemini while fields are
            # still empty means that call just failed: fall back to regex.
            if state.history and state.history[-1] == 'use_gemini':
                return AgentDecision.USE_REGEX
            return AgentDecision.USE_GEMINI

        # 4. If fields exist, validate them
        if not self._is_validated(state):
            return AgentDecision.VALIDATE

        # 5. Check confidence - retry if low
        overall_confidence = self._calculate_overall_confidence(state)
        if overall_confidence < 0.6 and not out_of_attempts:
            # Try alternative approach
            if 'use_gemini' in state.history and 'use_regex' not in state.history:
                return AgentDecision.USE_REGEX
            elif 'extract_tables' not in state.history:
                return AgentDecision.EXTRACT_TABLES
            else:
                return AgentDecision.RETRY

        # 6. If still low confidence, request human review
        if overall_confidence < 0.5:
            return AgentDecision.HUMAN_REVIEW

        # 7. Otherwise, we're done!
        return AgentDecision.COMPLETE

    def _execute_action(self, decision: "AgentDecision", state: "AgentState") -> bool:
        """
        Execute the decided action.

        Returns:
            True when the action succeeded.  Any exception raised by the
            action is logged, recorded in ``state.errors`` and reported as
            False instead of propagating.
        """
        try:
            if decision == AgentDecision.EXTRACT_TEXT:
                return self._extract_text(state)
            elif decision == AgentDecision.EXTRACT_TABLES:
                return self._extract_tables(state)
            elif decision == AgentDecision.RUN_NER:
                return self._run_ner(state)
            elif decision == AgentDecision.USE_GEMINI:
                return self._use_gemini(state)
            elif decision == AgentDecision.USE_REGEX:
                return self._use_regex(state)
            elif decision == AgentDecision.VALIDATE:
                return self._validate_fields(state)
            elif decision == AgentDecision.RETRY:
                # Clear fields and try again with different approach
                state.fields = None
                state.confidence_map = None
                return True
            elif decision in [AgentDecision.COMPLETE, AgentDecision.HUMAN_REVIEW]:
                return True
            return False
        except Exception as e:
            logger.error(f"**** Action failed: {e}")
            state.errors.append(str(e))
            return False

    def _extract_text(self, state: "AgentState") -> bool:
        """Extract text from document; more than 10 non-blank characters are required."""
        logger.info("**** Extracting text...")
        success, text, error = self.text_extractor(state.file_path)
        if success and text and len(text.strip()) > 10:
            state.raw_text = text
            logger.info(f"**** Extracted {len(text)} characters")
            return True
        # The extractor may report success yet return too little text, in
        # which case it supplies no error message of its own.
        state.errors.append(f"Text extraction failed: {error or 'no usable text returned'}")
        return False

    def _extract_tables(self, state: "AgentState") -> bool:
        """Extract tables from document; failure is non-critical."""
        logger.info("**** Extracting tables...")
        success, tables, error = self.table_extractor(state.file_path)
        if success:
            # Tolerate an extractor that reports success with no table list.
            state.tables = tables or []
            logger.info(f"**** Extracted {len(state.tables)} tables")
            return True
        logger.warning(f"**** Table extraction failed: {error}")
        state.tables = []
        return True  # Non-critical, continue

    def _run_ner(self, state: "AgentState") -> bool:
        """Run Named Entity Recognition; failure is non-critical."""
        logger.info("**** Running NER...")
        success, entities, entity_map, error = self.ner_extractor(state.raw_text)
        if success:
            # Never store None here: None means "NER not run yet" to the
            # decision logic and would make it run NER again.
            state.entities = entities or []
            state.entity_map = entity_map or {}
            logger.info(f"**** Found {len(state.entities)} entities")
            return True
        logger.warning(f"**** NER failed: {error}")
        state.entities = []
        state.entity_map = {}
        return True  # Non-critical, continue

    def _use_gemini(self, state: "AgentState") -> bool:
        """
        Use Gemini for intelligent mapping.

        Missing or empty values in the mapper result fall back to the same
        placeholder defaults that validation later flags.  An amount that
        cannot be converted to float raises, which the caller treats as a
        failed Gemini pass.
        """
        logger.info("**** Using Gemini mapping...")
        success, result, error = self.gemini_mapper(
            state.raw_text,
            state.entities or [],
            state.entity_map or {},
            state.tables or []
        )
        if success and result:
            state.fields = {
                'cust_number': str(result.get('customer_name') or 'UNKNOWN')[:20],
                'posting_date': result.get('date') or '2024-01-01',
                'total_open_amount': float(result.get('total_amount') or 0.0),
                'business_code': 'U001',
                'cust_payment_terms': str(result.get('payment_terms') or 'NAH4')[:10]
            }
            # High confidence from Gemini
            state.confidence_map = {
                'cust_number': 0.9,
                'posting_date': 0.9,
                'total_open_amount': 0.9,
                'business_code': 0.3,
                'cust_payment_terms': 0.8
            }
            logger.info(f"**** Gemini mapped: {state.fields}")
            return True
        logger.warning(f"**** Gemini failed: {error}")
        state.errors.append(f"Gemini mapping failed: {error}")
        return False

    def _use_regex(self, state: "AgentState") -> bool:
        """Fallback regex-based extraction."""
        logger.info("**** Using regex fallback...")
        # Imported lazily, as in the original, so this module can be loaded
        # without the ingest module.
        from backend.app.api.ingest import map_with_regex
        fields, confidence = map_with_regex(state.raw_text, state.entities or [])
        state.fields = fields
        state.confidence_map = confidence
        logger.info(f"**** Regex mapped: {fields}")
        return True

    def _validate_fields(self, state: "AgentState") -> bool:
        """
        Validate extracted fields using business rules.

        Each failed check halves the confidence recorded for that field.
        The ``'validated'`` marker is always appended to the history, even
        when there are no fields, so the decision logic can move on.

        Returns:
            True when at least two of the three checks pass.
        """
        logger.info("✓ Validating fields...")
        if not state.fields:
            # Nothing to check, but record that validation happened.
            state.history.append('validated')
            return False

        validation_results = {}

        # 1. Customer number shouldn't be empty or generic
        cust = state.fields.get('cust_number', '')
        cust_text = '' if cust is None else str(cust)
        validation_results['cust_number'] = (
            bool(cust_text) and cust_text != 'UNKNOWN' and len(cust_text) > 2
        )
        if not validation_results['cust_number']:
            logger.warning("**** Customer number looks invalid")

        # 2. Date should be reasonable (not default)
        date = state.fields.get('posting_date', '')
        validation_results['posting_date'] = bool(date) and date != '2024-01-01'
        if not validation_results['posting_date']:
            logger.warning("**** Date looks like default value")

        # 3. Amount should be > 0 (a non-numeric amount counts as invalid)
        amount = state.fields.get('total_open_amount', 0.0)
        try:
            validation_results['total_open_amount'] = float(amount) > 0
        except (TypeError, ValueError):
            validation_results['total_open_amount'] = False
        if not validation_results['total_open_amount']:
            logger.warning("**** Amount is zero or missing")

        # Adjust confidence based on validation.  A mapper is not obliged
        # to report a confidence for every field, so skip absent or
        # non-numeric entries instead of raising.
        if state.confidence_map:
            for field, is_valid in validation_results.items():
                current = state.confidence_map.get(field)
                if not is_valid and isinstance(current, (int, float)):
                    state.confidence_map[field] = current * 0.5  # Reduce confidence

        # Mark as validated
        state.history.append('validated')

        success_count = sum(validation_results.values())
        logger.info(f"✓ Validation: {success_count}/{len(validation_results)} checks passed")
        return success_count >= 2  # At least 2 fields should be valid

    def _is_validated(self, state: "AgentState") -> bool:
        """
        Check if the current fields have been validated.

        Walks the history backwards: a ``'validated'`` marker only counts
        when no mapping pass or retry happened after it, so re-mapped
        fields are validated again.
        """
        for entry in reversed(state.history):
            if entry == 'validated':
                return True
            if entry in self._MAPPING_ACTIONS or entry == 'retry':
                return False
        return False

    def _calculate_overall_confidence(self, state: "AgentState") -> float:
        """
        Calculate overall confidence score.

        Returns:
            The weighted average over the fields present in
            ``state.confidence_map``, or 0.0 when the map is empty or holds
            none of the weighted fields.
        """
        if not state.confidence_map:
            return 0.0

        # Weighted average (important fields have more weight)
        weights = {
            'cust_number': 0.3,
            'posting_date': 0.2,
            'total_open_amount': 0.3,
            'cust_payment_terms': 0.1,
            'business_code': 0.1
        }
        total_confidence = 0.0
        total_weight = 0.0
        for field, weight in weights.items():
            if field in state.confidence_map:
                total_confidence += state.confidence_map[field] * weight
                total_weight += weight
        return total_confidence / total_weight if total_weight > 0 else 0.0
| # ============================================== | |
| # Integration with existing code | |
| # ============================================== | |
def create_agent(text_extractor_fn, table_extractor_fn, ner_fn, gemini_fn):
    """
    Build an ``InvoiceAgent`` wired to the four supplied tool callables.

    Args:
        text_extractor_fn: Passed to the agent as its text extractor.
        table_extractor_fn: Passed to the agent as its table extractor.
        ner_fn: Passed to the agent as its NER extractor.
        gemini_fn: Passed to the agent as its Gemini mapper.

    Returns:
        A new ``InvoiceAgent`` instance.

    Usage:
        from backend.app.api.ingest import (
            call_text_extractor, call_table_extractor,
            call_ner, map_with_gemini
        )
        agent = create_agent(
            call_text_extractor,
            call_table_extractor,
            call_ner,
            map_with_gemini
        )
        state = AgentState(doc_id="doc123", file_path=Path("invoice.pdf"))
        result_state = agent.process(state)
    """
    agent = InvoiceAgent(
        text_extractor=text_extractor_fn,
        table_extractor=table_extractor_fn,
        ner_extractor=ner_fn,
        gemini_mapper=gemini_fn,
    )
    return agent
def run_agent_pipeline(job_id: str, doc_id: str, file_path: Path):
    """
    Replace your existing process_document() with this agentic version.

    Runs the agent on one document, saves whatever it extracted, and
    updates the job status: ``needs_review`` when the agent asked for human
    review, ``completed`` otherwise, and ``failed`` when no fields were
    produced or an exception occurred.

    Args:
        job_id: Job whose status is updated as the run progresses.
        doc_id: Document identifier handed to the save functions.
        file_path: Location of the document to process.
    """
    from backend.app.api.ingest import (
        call_text_extractor, call_table_extractor,
        call_ner, map_with_gemini,
        save_extraction, save_invoice_fields,
        update_job_status
    )
    try:
        update_job_status(job_id, 'processing')

        # Create agent
        agent = create_agent(
            call_text_extractor,
            call_table_extractor,
            call_ner,
            map_with_gemini
        )

        # Initialize state
        state = AgentState(doc_id=doc_id, file_path=file_path)

        # Let agent decide and execute autonomously
        result_state = agent.process(state)

        # Save results
        if result_state.fields:
            save_extraction(
                doc_id,
                result_state.raw_text,
                result_state.tables or [],
                result_state.entities or [],
                {
                    'method': 'autonomous_agent',
                    'attempts': result_state.attempts,
                    'actions': result_state.history,
                    'confidence': agent._calculate_overall_confidence(result_state)
                },
                None
            )
            save_invoice_fields(
                doc_id,
                result_state.fields,
                result_state.confidence_map or {}
            )

            # Check if needs human review
            if AgentDecision.HUMAN_REVIEW.value in result_state.history:
                update_job_status(job_id, 'needs_review')
            else:
                update_job_status(job_id, 'completed')
            logger.info(f"**** Agent completed with {len(result_state.history)} actions")
        else:
            update_job_status(job_id, 'failed', 'Agent could not extract fields')
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback,
        # so the stack trace goes through the configured log handlers
        # instead of being printed separately to stderr.
        logger.exception(f"**** Agent failed: {e}")
        update_job_status(job_id, 'failed', str(e))