import os
import uuid
import json
import sqlite3
import logging
import csv
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, List, Any
from fastapi import APIRouter, UploadFile, File, HTTPException, BackgroundTasks, Form
from pydantic import BaseModel
from filelock import FileLock
import httpx
import re
import sys
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True
)
logger = logging.getLogger(__name__)

BASE_DIR = Path(__file__).parent.parent.parent.parent
STORAGE_PATH = Path(os.getenv('STORAGE_PATH', str(BASE_DIR / "data" / "docs")))
DB_PATH = Path(os.getenv('DB_PATH', str(BASE_DIR / "data" / "invoices.db")))
LOCK_PATH = BASE_DIR / "data" / "invoices.db.lock"
PREDICT_ENDPOINT = 'http://localhost:7860/predict'
STORAGE_PATH.mkdir(parents=True, exist_ok=True)

router = APIRouter(prefix="/api", tags=["ingest"])

def _init_db_tables():
    """Create tables on module import so the HF Space always has them."""
    try:
        logger.info("🔍 Checking if database tables exist...")
        DB_PATH.parent.mkdir(parents=True, exist_ok=True)
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            cursor = conn.cursor()
            # Quick existence check
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='ingest_jobs'")
            if cursor.fetchone():
                conn.close()
                logger.info("✅ Database tables already exist")
                return
            logger.warning("⚠️ Database tables not found, creating...")
            # Create all tables
            tables_sql = [
                """CREATE TABLE IF NOT EXISTS ingest_jobs (
                    job_id TEXT PRIMARY KEY,
                    doc_id TEXT,
                    filename TEXT NOT NULL,
                    status TEXT NOT NULL DEFAULT 'queued',
                    error_text TEXT,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    updated_at TEXT DEFAULT CURRENT_TIMESTAMP
                )""",
                """CREATE TABLE IF NOT EXISTS documents (
                    doc_id TEXT PRIMARY KEY,
                    job_id TEXT NOT NULL,
                    path TEXT NOT NULL,
                    filename TEXT NOT NULL,
                    content_type TEXT NOT NULL,
                    uploaded_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (job_id) REFERENCES ingest_jobs(job_id)
                )""",
                """CREATE TABLE IF NOT EXISTS extractions (
                    doc_id TEXT PRIMARY KEY,
                    raw_text TEXT,
                    tables_json TEXT,
                    entities_json TEXT,
                    classification_json TEXT,
                    summary_text TEXT,
                    extracted_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (doc_id) REFERENCES documents(doc_id)
                )""",
                """CREATE TABLE IF NOT EXISTS invoice_fields (
                    invoice_id INTEGER PRIMARY KEY AUTOINCREMENT,
                    doc_id TEXT NOT NULL,
                    cust_number TEXT,
                    posting_date TEXT,
                    total_open_amount REAL,
                    business_code TEXT,
                    cust_payment_terms TEXT,
                    confidence_map TEXT,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (doc_id) REFERENCES documents(doc_id)
                )""",
                """CREATE TABLE IF NOT EXISTS batch_jobs (
                    batch_id TEXT PRIMARY KEY,
                    total_files INTEGER,
                    message TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )""",
                """CREATE TABLE IF NOT EXISTS batch_job_mapping (
                    batch_id TEXT,
                    job_id TEXT,
                    FOREIGN KEY (batch_id) REFERENCES batch_jobs(batch_id),
                    FOREIGN KEY (job_id) REFERENCES ingest_jobs(job_id)
                )""",
                # ML feature tables
                """CREATE TABLE IF NOT EXISTS customer_aggregates (
                    cust_number TEXT PRIMARY KEY,
                    cust_avg_days REAL,
                    cust_median_days REAL,
                    cust_invoice_count INTEGER,
                    last_updated TEXT DEFAULT CURRENT_TIMESTAMP
                )""",
                """CREATE TABLE IF NOT EXISTS payment_terms_aggregates (
                    cust_payment_terms TEXT PRIMARY KEY,
                    terms_avg_days REAL,
                    terms_median_days REAL,
                    terms_invoice_count INTEGER,
                    last_updated TEXT DEFAULT CURRENT_TIMESTAMP
                )""",
                """CREATE TABLE IF NOT EXISTS business_code_aggregates (
                    business_code TEXT PRIMARY KEY,
                    bc_avg_days REAL,
                    bc_median_days REAL,
                    bc_invoice_count INTEGER,
                    last_updated TEXT DEFAULT CURRENT_TIMESTAMP
                )""",
                """CREATE TABLE IF NOT EXISTS predictions_log (
                    prediction_id INTEGER PRIMARY KEY AUTOINCREMENT,
                    invoice_id INTEGER,
                    cust_number TEXT,
                    posting_date TEXT,
                    total_open_amount REAL,
                    business_code TEXT,
                    cust_payment_terms TEXT,
                    predicted_days_to_clear REAL,
                    predicted_clear_date TEXT,
                    model_version TEXT,
                    features_json TEXT,
                    predicted_at TEXT DEFAULT CURRENT_TIMESTAMP
                )"""
            ]
            # Execute all CREATE TABLE statements
            for sql in tables_sql:
                cursor.execute(sql)
            # Create indexes
            indexes_sql = [
                "CREATE INDEX IF NOT EXISTS idx_ingest_jobs_status ON ingest_jobs(status)",
                "CREATE INDEX IF NOT EXISTS idx_ingest_jobs_created ON ingest_jobs(created_at DESC)",
                "CREATE INDEX IF NOT EXISTS idx_documents_job_id ON documents(job_id)",
                "CREATE INDEX IF NOT EXISTS idx_invoice_fields_doc_id ON invoice_fields(doc_id)",
                "CREATE INDEX IF NOT EXISTS idx_batch_mapping_batch ON batch_job_mapping(batch_id)",
                "CREATE INDEX IF NOT EXISTS idx_predictions_log_cust ON predictions_log(cust_number)"
            ]
            for sql in indexes_sql:
                cursor.execute(sql)
            conn.commit()
            conn.close()
            logger.info("✅ Database tables created successfully!")
    except Exception as e:
        logger.error(f"❌ Failed to create tables: {e}")
        import traceback
        logger.error(traceback.format_exc())

# Run on module import
logger.info("🔍 Initializing database on module load...")
try:
    _init_db_tables()
    logger.info("✅ Database initialization complete")
except Exception as e:
    logger.error(f"❌ Database initialization failed: {e}")
    logger.warning("⚠️ Application may not work correctly!")

# ============================================
# LOCAL OCR FALLBACK (EasyOCR + Tesseract)
# ============================================
def extract_text_with_easyocr(file_path: Path) -> tuple:
    """
    EasyOCR - best free open-source OCR:
    - works offline
    - 80+ languages
    - GPU/CPU support
    - better accuracy than Tesseract on invoices
    """
    try:
        import easyocr
        logger.info("🔧 Using EasyOCR (best free OCR)...")
        # Initialize reader (downloads models on first run).
        # Use GPU if available; set gpu=True if you have CUDA.
        reader = easyocr.Reader(['en'], gpu=False)
        # Read the image
        result = reader.readtext(str(file_path), detail=0, paragraph=True)
        # Join all text
        text = '\n'.join(result)
        if text and len(text.strip()) >= 10:
            logger.info(f"✅ EasyOCR extracted {len(text)} characters")
            return True, text, None
        return False, None, "EasyOCR produced no usable text"
    except ImportError:
        logger.warning("⚠️ easyocr not installed. Install with: pip install easyocr")
        return False, None, "easyocr not available"
    except Exception as e:
        logger.error(f"❌ EasyOCR failed: {e}")
        return False, None, str(e)

def extract_text_with_tesseract(file_path: Path) -> tuple:
    """
    Tesseract OCR - fallback option.
    Faster but less accurate than EasyOCR.
    """
    try:
        import pytesseract
        from PIL import Image
        logger.info("🔧 Using Tesseract OCR as secondary fallback...")
        image = Image.open(file_path)
        text = pytesseract.image_to_string(image)
        if text and len(text.strip()) >= 10:
            logger.info(f"✅ Tesseract extracted {len(text)} characters")
            return True, text, None
        return False, None, "Tesseract produced no usable text"
    except ImportError:
        logger.warning("⚠️ pytesseract not installed. Install with: pip install pytesseract pillow")
        return False, None, "pytesseract not available"
    except Exception as e:
        logger.error(f"❌ Tesseract failed: {e}")
        return False, None, str(e)

def extract_text_with_local_ocr(file_path: Path) -> tuple:
    """
    Multi-tier local OCR fallback system:
    1. Try EasyOCR (best accuracy)
    2. Try Tesseract (faster, less accurate)
    3. Give up
    """
    logger.info("=" * 70)
    logger.info("🔄 HF extraction failed - trying local OCR fallbacks...")
    logger.info("=" * 70)
    # Priority 1: EasyOCR (best for invoices)
    success, text, error = extract_text_with_easyocr(file_path)
    if success:
        logger.info("✅ EasyOCR succeeded!")
        return True, text, None
    logger.warning(f"⚠️ EasyOCR failed: {error}")
    # Priority 2: Tesseract (faster fallback)
    success, text, error = extract_text_with_tesseract(file_path)
    if success:
        logger.info("✅ Tesseract succeeded!")
        return True, text, None
    logger.warning(f"⚠️ Tesseract failed: {error}")
    # All local OCR failed
    logger.error("❌ All local OCR methods failed")
    return False, None, "All local OCR methods failed"
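
# Minimal usage sketch (hypothetical file path) for the fallback chain above;
# the return shape is the (success, text, error) tuple used throughout:
#   ok, text, err = extract_text_with_local_ocr(Path("data/docs/sample_invoice.png"))
#   if ok:
#       print(text[:200])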

# ============================================
# STEP 1: HF Agent Text Extraction
# ============================================
def get_agent_headers():
    """Build Authorization headers from the first HF token found in the environment."""
    token = (
        os.getenv('HF_TOKEN') or
        os.getenv('HUGGINGFACE_API_TOKEN') or
        os.getenv('AGENT_BEARER_TOKEN') or
        ''
    )
    return {'Authorization': f'Bearer {token}'} if token else {}

def get_mime_type(file_path: Path) -> str:
    """Map a file extension to its MIME type."""
    ext = file_path.suffix.lower()
    mime_map = {
        '.pdf': 'application/pdf',
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png'
    }
    return mime_map.get(ext, 'application/octet-stream')
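
# e.g. get_mime_type(Path("invoice.PDF")) -> 'application/pdf'
# (unknown extensions fall back to 'application/octet-stream')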

def call_text_extractor(file_path: Path, max_retries=3):
    """
    HF text extraction with retry logic and progressively longer timeouts.
    Falls back to local OCR if all retries fail.
    """
    url = os.getenv('TEXT_EXTRACTOR_URL', 'https://point9-extract-text-and-table.hf.space/api/text')
    base_timeout = int(os.getenv('AGENT_TIMEOUT_SECONDS', '120'))
    for attempt in range(max_retries):
        # Progressive timeout: 120s, 180s, 240s
        timeout = base_timeout + (60 * attempt)
        try:
            logger.info(f"📄 Extracting text from {file_path.name} (attempt {attempt + 1}/{max_retries}, timeout={timeout}s)...")
            filename = file_path.name
            mime_type = get_mime_type(file_path)
            with open(file_path, 'rb') as f:
                files = {'file': (filename, f, mime_type)}
                data = {
                    'filename': filename,
                    'start_page': 1,
                    'end_page': 1
                }
                headers = get_agent_headers()
                response = httpx.post(url, files=files, data=data, headers=headers, timeout=timeout)
            if response.status_code == 200:
                result = response.json()
                text = result.get('result') or result.get('text') or result.get('extracted_text') or ''
                if text and len(text.strip()) >= 10:
                    logger.info(f"✅ Extracted {len(text)} characters")
                    return True, text, None
                logger.warning("⚠️ No text extracted from response")
                if attempt < max_retries - 1:
                    continue
                return False, None, "No text extracted"
            logger.warning(f"⚠️ HTTP {response.status_code}: {response.text[:200]}")
        except httpx.TimeoutException:
            logger.warning(f"⚠️ Timeout after {timeout}s on attempt {attempt + 1}")
            if attempt < max_retries - 1:
                logger.info("🔄 Retrying with longer timeout...")
                continue
        except Exception as e:
            logger.error(f"❌ Error on attempt {attempt + 1}: {e}")
            if attempt < max_retries - 1:
                logger.info("🔄 Retrying...")
                continue
    # All retries failed - try the local OCR fallback
    logger.warning(f"⚠️ All {max_retries} HF extraction attempts failed, trying local OCR fallback...")
    return extract_text_with_local_ocr(file_path)
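
# The endpoint and base timeout are environment-driven; illustrative overrides
# (values are examples, not the defaults above):
#   export TEXT_EXTRACTOR_URL=https://my-space.hf.space/api/text
#   export AGENT_TIMEOUT_SECONDS=180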

def call_table_extractor(file_path: Path, max_retries=2):
    """
    HF table extraction with retry logic.
    Non-critical, so fewer retries.
    """
    url = os.getenv('TABLE_EXTRACTOR_URL', 'https://point9-extract-text-and-table.hf.space/api/tables')
    base_timeout = int(os.getenv('AGENT_TIMEOUT_SECONDS', '120'))
    for attempt in range(max_retries):
        timeout = base_timeout + (60 * attempt)
        try:
            logger.info(f"📊 Extracting tables from {file_path.name} (attempt {attempt + 1}/{max_retries})...")
            filename = file_path.name
            mime_type = get_mime_type(file_path)
            with open(file_path, 'rb') as f:
                files = {'file': (filename, f, mime_type)}
                data = {
                    'filename': filename,
                    'start_page': 1,
                    'end_page': 1
                }
                headers = get_agent_headers()
                response = httpx.post(url, files=files, data=data, headers=headers, timeout=timeout)
            if response.status_code == 200:
                result = response.json()
                tables = result.get('result') or result.get('tables') or []
                logger.info(f"✅ Extracted {len(tables)} tables")
                return True, tables, None
            logger.warning(f"⚠️ HTTP {response.status_code}")
        except httpx.TimeoutException:
            logger.warning(f"⚠️ Table extraction timeout on attempt {attempt + 1}")
        except Exception as e:
            logger.warning(f"⚠️ Table extraction error: {e}")
    # Non-critical - return an empty list
    logger.info("ℹ️ Table extraction failed, continuing without tables")
    return False, [], "Table extraction failed (non-critical)"

# ============================================
# STEP 2: HF NER (Named Entity Recognition)
# ============================================
def call_ner(text: str, file_path: Path = None, max_retries=2) -> tuple:
    """Extract named entities using the HF NER agent, with retry logic."""
    url = os.getenv('NER_URL', 'https://point9-ner.hf.space/api/ner')
    base_timeout = int(os.getenv('AGENT_TIMEOUT_SECONDS', '120'))
    for attempt in range(max_retries):
        timeout = base_timeout + (30 * attempt)
        try:
            logger.info(f"🔍 Running NER to find entities (attempt {attempt + 1}/{max_retries})...")
            headers = get_agent_headers()
            # NER expects multipart/form-data with a file OR plain text
            if file_path and file_path.exists():
                # Send the file
                filename = file_path.name
                mime_type = get_mime_type(file_path)
                with open(file_path, 'rb') as f:
                    files = {'file': (filename, f, mime_type)}
                    data = {
                        'text': text[:5000],
                        'filename': filename,
                        'start_page': 1,
                        'end_page': 1
                    }
                    response = httpx.post(url, files=files, data=data, headers=headers, timeout=timeout)
            else:
                # Send just the text as form data
                data = {
                    'text': text[:5000],
                    'filename': 'document.txt',
                    'start_page': 1,
                    'end_page': 1
                }
                response = httpx.post(url, data=data, headers=headers, timeout=timeout)
            if response.status_code == 200:
                result = response.json()
                # Handle both dict and string responses
                if isinstance(result, str):
                    try:
                        result = json.loads(result)
                    except json.JSONDecodeError:
                        logger.warning(f"⚠️ NER returned unparseable string: {result[:100]}")
                        if attempt < max_retries - 1:
                            continue
                        return False, [], {}, "Invalid response format"
                # Extract entities
                entities = result.get('entities') or result.get('result') or []
                # The entities payload itself might also arrive as a string
                if isinstance(entities, str):
                    try:
                        entities = json.loads(entities)
                    except json.JSONDecodeError:
                        entities = []
                logger.info(f"✅ Found {len(entities)} entities")
                # Group entities by type
                entity_map = {
                    'PERSON': [],
                    'ORG': [],
                    'DATE': [],
                    'MONEY': [],
                    'CARDINAL': []
                }
                for entity in entities:
                    if not isinstance(entity, dict):
                        continue
                    ent_type = entity.get('entity_type') or entity.get('label')
                    ent_text = entity.get('text') or entity.get('word')
                    if ent_type in entity_map and ent_text:
                        entity_map[ent_type].append(ent_text)
                logger.info(
                    f"📋 Entity summary: PERSON={len(entity_map['PERSON'])}, ORG={len(entity_map['ORG'])}, "
                    f"DATE={len(entity_map['DATE'])}, MONEY={len(entity_map['MONEY'])}"
                )
                return True, entities, entity_map, None
            logger.warning(f"⚠️ NER HTTP {response.status_code}")
        except httpx.TimeoutException:
            logger.warning(f"⚠️ NER timeout on attempt {attempt + 1}")
        except Exception as e:
            logger.error(f"❌ NER error on attempt {attempt + 1}: {e}")
    # NER failed - return empty results (non-critical)
    logger.warning("⚠️ NER failed after retries, continuing without entities")
    return False, [], {}, "NER failed (non-critical)"
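
# Note: the grouping loop above accepts entity dicts under either key set, e.g.
#   {'entity_type': 'ORG', 'text': 'Acme Corp'}  or  {'label': 'ORG', 'word': 'Acme Corp'}
# (example values are illustrative).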

# ============================================
# STEP 3: Gemini Intelligent Mapping
# ============================================
def map_with_gemini(text: str, entities: List, entity_map: Dict, tables: List):
    """Use Gemini to intelligently map extracted data to invoice fields."""
    try:
        import google.generativeai as genai
        api_key = os.getenv('GEMINI_API_KEY')
        if not api_key:
            logger.warning("⚠️ No Gemini API key configured")
            return False, None, "No Gemini API key"
        logger.info("🧠 Using Gemini for intelligent field mapping...")
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('models/gemini-2.5-flash')
        # Build context for Gemini
        context = f"""
EXTRACTED TEXT:
{text[:3000]}

NAMED ENTITIES FOUND:
- Organizations: {entity_map.get('ORG', [])}
- People: {entity_map.get('PERSON', [])}
- Dates: {entity_map.get('DATE', [])}
- Money amounts: {entity_map.get('MONEY', [])}
- Numbers: {entity_map.get('CARDINAL', [])}

TABLES:
{json.dumps(tables[:2], indent=2) if tables else 'None'}
"""
        prompt = f"""You are an expert at analyzing invoice data. Given the extracted text and entities below, map them to invoice fields.
{context}
Analyze the above data and return ONLY a valid JSON object with these exact fields:
{{
    "customer_name": "the client/customer company name (check ORG entities first)",
    "invoice_number": "the invoice number (check CARDINAL entities)",
    "date": "invoice date in YYYY-MM-DD format (check DATE entities)",
    "total_amount": numeric total amount only (check MONEY entities, no currency symbol),
    "payment_terms": "payment terms like NET30, NET60, or NAH4 if not found",
    "reasoning": "brief explanation of how you identified each field"
}}
Rules:
1. Prefer entities over raw text when available
2. Customer name is usually the first ORG after "Bill To" or "Client"
3. Total amount is usually the largest MONEY value
4. Date should be in YYYY-MM-DD format
5. If uncertain, use these defaults: customer_name="UNKNOWN", date="2024-01-01", total_amount=0.0, payment_terms="NAH4"
Return ONLY the JSON object, no markdown, no explanation outside the JSON."""
        response = model.generate_content(prompt)
        text_response = response.text.strip()
        # Remove markdown fences if present
        text_response = text_response.replace('```json', '').replace('```', '').strip()
        result = json.loads(text_response)
        logger.info(f"✅ Gemini mapped: Customer={result.get('customer_name')}, Amount=${result.get('total_amount')}")
        logger.info(f"💡 Reasoning: {result.get('reasoning', 'N/A')[:100]}")
        return True, result, None
    except json.JSONDecodeError as e:
        logger.error(f"❌ Gemini returned invalid JSON: {e}")
        logger.error(f"Response: {text_response[:500]}")
        return False, None, f"Invalid JSON: {e}"
    except Exception as e:
        logger.error(f"❌ Gemini mapping failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return False, None, str(e)
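
# Shape of a successful result parsed from Gemini, per the prompt above
# (values are illustrative):
#   {"customer_name": "Acme Corp", "invoice_number": "INV-001", "date": "2024-01-15",
#    "total_amount": 1250.0, "payment_terms": "NET30", "reasoning": "..."}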

# ============================================
# Fallback: Regex Mapping
# ============================================
def map_with_regex(text: str, entities: List) -> tuple:
    """Fallback regex-based field extraction"""
    logger.info("🔤 Using regex fallback for field mapping...")
    fields = {}
    confidence = {}
    # CUSTOMER NAME - try ORG entities first
    org_entities = [e.get('text') or e.get('word') for e in entities
                    if (e.get('entity_type') or e.get('label')) == 'ORG']
    if org_entities:
        fields['cust_number'] = org_entities[0][:20]
        confidence['cust_number'] = 0.8
    else:
        # Regex fallback
        client_patterns = [
            r'(?:Client|Bill\s+To|Customer)[:\s]+(.*?)(?:\n|Tax|IBAN)',
            r'(?:customer|client)[\s:]+([A-Za-z][A-Za-z\s,&-]+?)(?:\n|$)',
        ]
        for pattern in client_patterns:
            match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
            if match:
                client = match.group(1).strip()
                words = [w.strip() for w in client.replace(',', ' ').split() if len(w.strip()) > 2]
                if words:
                    fields['cust_number'] = words[0][:20]
                    confidence['cust_number'] = 0.6
                    break
    if 'cust_number' not in fields:
        fields['cust_number'] = 'UNKNOWN'
        confidence['cust_number'] = 0.1
    # DATE - try DATE entities first
    date_entities = [e.get('text') or e.get('word') for e in entities
                     if (e.get('entity_type') or e.get('label')) == 'DATE']
    if date_entities:
        date_str = date_entities[0]
        for fmt in ['%m/%d/%Y', '%d/%m/%Y', '%Y-%m-%d', '%m-%d-%Y']:
            try:
                dt = datetime.strptime(date_str, fmt)
                fields['posting_date'] = dt.strftime('%Y-%m-%d')
                confidence['posting_date'] = 0.8
                break
            except ValueError:
                continue
    if 'posting_date' not in fields:
        date_patterns = [
            r'(?:Date\s+of\s+issue|Invoice\s+Date|Date)[:\s]+(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
        ]
        for pattern in date_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                date_str = match.group(1)
                for fmt in ['%m/%d/%Y', '%d/%m/%Y']:
                    try:
                        dt = datetime.strptime(date_str, fmt)
                        fields['posting_date'] = dt.strftime('%Y-%m-%d')
                        confidence['posting_date'] = 0.7
                        break
                    except ValueError:
                        continue
            if 'posting_date' in fields:
                break
    if 'posting_date' not in fields:
        fields['posting_date'] = datetime.now().strftime('%Y-%m-%d')
        confidence['posting_date'] = 0.1
    # AMOUNT - try MONEY entities first
    money_entities = [e.get('text') or e.get('word') for e in entities
                      if (e.get('entity_type') or e.get('label')) == 'MONEY']
    if money_entities:
        amounts = []
        for money_str in money_entities:
            try:
                # Strip currency symbols and separators, then parse
                amt_str = re.sub(r'[^\d.]', '', money_str)
                amt = float(amt_str)
                if amt > 10:
                    amounts.append(amt)
            except ValueError:
                pass
        if amounts:
            fields['total_open_amount'] = max(amounts)
            confidence['total_open_amount'] = 0.8
            logger.info(f"✅ Found amount from MONEY entity: ${fields['total_open_amount']}")
    if 'total_open_amount' not in fields:
        # Regex fallback
        pattern = r'\$\s*([0-9]{1,3}(?:,?[0-9]{3})*\.[0-9]{2})'
        amounts = []
        for match in re.finditer(pattern, text):
            try:
                amt = float(match.group(1).replace(',', ''))
                if amt > 50:
                    amounts.append(amt)
            except ValueError:
                pass
        if amounts:
            fields['total_open_amount'] = max(amounts)
            confidence['total_open_amount'] = 0.6
        else:
            fields['total_open_amount'] = 0.0
            confidence['total_open_amount'] = 0.0
            logger.warning("⚠️ No amount found!")
    # PAYMENT TERMS
    terms_match = re.search(r'(NET\s?\d{1,2}|N\d{2}|NAH\d)', text, re.IGNORECASE)
    fields['cust_payment_terms'] = terms_match.group(1).upper() if terms_match else 'NAH4'
    confidence['cust_payment_terms'] = 0.7 if terms_match else 0.2
    # BUSINESS CODE
    fields['business_code'] = 'U001'
    confidence['business_code'] = 0.2
    return fields, confidence
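
# Worked example (hypothetical invoice text, no NER entities):
#   map_with_regex("Bill To: Acme Corp\nDate: 01/15/2024\nTotal: $1,250.00", [])
# returns roughly:
#   ({'cust_number': 'Acme', 'posting_date': '2024-01-15', 'total_open_amount': 1250.0,
#     'cust_payment_terms': 'NAH4', 'business_code': 'U001'}, {...per-field confidences})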

# ============================================
# Database Functions
# ============================================
def update_job_status(job_id: str, status: str, error_text: str = None):
    """Update job status"""
    with FileLock(str(LOCK_PATH), timeout=10):
        conn = sqlite3.connect(str(DB_PATH))
        cursor = conn.cursor()
        cursor.execute("""
            UPDATE ingest_jobs
            SET status = ?, error_text = ?, updated_at = CURRENT_TIMESTAMP
            WHERE job_id = ?
        """, (status, error_text, job_id))
        conn.commit()
        conn.close()

def save_extraction(doc_id: str, raw_text: str, tables: list, entities: list, classification: dict, summary: str = None):
    """Save extraction results"""
    with FileLock(str(LOCK_PATH), timeout=10):
        conn = sqlite3.connect(str(DB_PATH))
        cursor = conn.cursor()
        cursor.execute("""
            INSERT OR REPLACE INTO extractions (
                doc_id, raw_text, tables_json, entities_json,
                classification_json, summary_text
            ) VALUES (?, ?, ?, ?, ?, ?)
        """, (
            doc_id,
            raw_text,
            json.dumps(tables) if tables else None,
            json.dumps(entities) if entities else None,
            json.dumps(classification) if classification else None,
            summary
        ))
        conn.commit()
        conn.close()

def save_invoice_fields(doc_id: str, fields: Dict, confidence_map: Dict):
    """Save invoice fields"""
    with FileLock(str(LOCK_PATH), timeout=10):
        conn = sqlite3.connect(str(DB_PATH))
        cursor = conn.cursor()
        cursor.execute("""
            INSERT INTO invoice_fields (
                doc_id, cust_number, posting_date, total_open_amount,
                business_code, cust_payment_terms, confidence_map
            ) VALUES (?, ?, ?, ?, ?, ?, ?)
        """, (
            doc_id,
            fields.get('cust_number'),
            fields.get('posting_date'),
            fields.get('total_open_amount'),
            fields.get('business_code'),
            fields.get('cust_payment_terms'),
            json.dumps(confidence_map)
        ))
        conn.commit()
        conn.close()

# ============================================
# AGENT MODE FLAG (Environment Variable)
# ============================================
USE_AGENT_MODE = os.getenv('USE_AGENT_MODE', 'true').lower() == 'true'
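# Agent mode is on by default; set USE_AGENT_MODE=false (any value other than
# 'true') to route documents through the legacy pipeline instead.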

# ============================================
# Main Processing Pipeline
# ============================================
def process_document_legacy(job_id: str, doc_id: str, file_path: Path):
    """
    Legacy pipeline (original implementation):
    1. HF extracts text + tables
    2. HF NER finds entities
    3. Gemini maps entities to invoice fields
    """
    logger.info("=" * 70)
    logger.info(f"🚀 Starting LEGACY pipeline for {file_path.name}")
    logger.info("=" * 70)
    try:
        update_job_status(job_id, 'processing')
        # STEP 1: Extract text with HF agents
        logger.info("STEP 1: HF TEXT + TABLE EXTRACTION")
        logger.info("-" * 70)
        success, raw_text, error = call_text_extractor(file_path)
        if not success or not raw_text:
            update_job_status(job_id, 'failed', f"Text extraction failed: {error}")
            return
        # Extract tables (optional; the job won't fail if this doesn't work)
        _, tables, _ = call_table_extractor(file_path)
        # STEP 2: NER to find entities
        logger.info("-" * 70)
        logger.info("STEP 2: NER - NAMED ENTITY RECOGNITION")
        logger.info("-" * 70)
        ner_success, entities, entity_map, ner_error = call_ner(raw_text, file_path)
        if not ner_success:
            logger.warning(f"⚠️ NER failed: {ner_error}, continuing without entities")
            entities = []
            entity_map = {}
        # STEP 3: Gemini intelligent mapping
        logger.info("-" * 70)
        logger.info("STEP 3: GEMINI INTELLIGENT MAPPING")
        logger.info("-" * 70)
        gemini_success, gemini_result, gemini_error = map_with_gemini(
            raw_text, entities, entity_map, tables
        )
        if gemini_success and gemini_result:
            # Use Gemini's mapping
            fields = {
                'cust_number': gemini_result.get('customer_name', 'UNKNOWN')[:20],
                'posting_date': gemini_result.get('date', datetime.now().strftime('%Y-%m-%d')),
                'total_open_amount': float(gemini_result.get('total_amount', 0.0)),
                'business_code': 'U001',
                'cust_payment_terms': gemini_result.get('payment_terms', 'NAH4')[:10]
            }
            confidence_map = {
                'cust_number': 0.95,
                'posting_date': 0.95,
                'total_open_amount': 0.95,
                'business_code': 0.2,
                'cust_payment_terms': 0.8
            }
            method = 'hf_ner_gemini'
        else:
            # Fall back to regex mapping
            logger.warning(f"⚠️ Gemini mapping failed: {gemini_error}")
            logger.info("-" * 70)
            logger.info("FALLBACK: REGEX MAPPING")
            logger.info("-" * 70)
            fields, confidence_map = map_with_regex(raw_text, entities)
            method = 'hf_ner_regex'
        # Save results
        save_extraction(
            doc_id, raw_text, tables, entities,
            {'method': method, 'entity_count': len(entities)},
            None
        )
        save_invoice_fields(doc_id, fields, confidence_map)
        logger.info("=" * 70)
        logger.info(f"✅ EXTRACTION COMPLETE - Method: {method}")
        logger.info(f"📋 Fields: {fields}")
        logger.info("=" * 70)
        # Prediction call (currently disabled in the legacy path)
        # logger.info("🔮 Calling payment prediction...")
        # try:
        #     pred_response = httpx.post(PREDICT_ENDPOINT, json=fields, timeout=30)
        #     if pred_response.status_code == 200:
        #         pred_result = pred_response.json()
        #         logger.info(f"✅ Prediction: {pred_result.get('predicted_days_to_clear')} days")
        # except Exception as e:
        #     logger.error(f"⚠️ Prediction failed: {e}")
        update_job_status(job_id, 'completed')
        logger.info(f"🎉 Job {job_id} completed successfully")
    except Exception as e:
        logger.error(f"❌ Job {job_id} failed: {e}")
        import traceback
        traceback.print_exc()
        update_job_status(job_id, 'failed', str(e))

def process_document_agent(job_id: str, doc_id: str, file_path: Path, user_message: Optional[str] = None):
    """
    Autonomous agent pipeline with an optional output-filtering wrapper.
    """
    try:
        # Clean up user_message
        if user_message in [None, 'None', '', 'null', 'undefined']:
            user_message = None
        else:
            user_message = str(user_message).strip()
            if not user_message:
                user_message = None
        logger.info("=" * 70)
        logger.info(f"🔍 AGENT - Processing with message: '{user_message}'")
        logger.info(f"🔍 Type: {type(user_message)}")
        logger.info(f"🔍 Is None: {user_message is None}")
        logger.info("=" * 70)
        from backend.app.agent.agent_orchestrator import (
            InvoiceAgent, AgentState, create_agent
        )
        logger.info("=" * 70)
        logger.info(f"🤖 AUTONOMOUS AGENT MODE for {file_path.name}")
        logger.info("=" * 70)
        update_job_status(job_id, 'processing')
        # Create the agent
        agent = create_agent(
            call_text_extractor,
            call_table_extractor,
            call_ner,
            map_with_gemini
        )
        # Initialize state
        state = AgentState(doc_id=doc_id, file_path=file_path)
        # Let the agent autonomously decide and execute
        result_state = agent.process(state)
        # ============================================
        # WRAPPER INTEGRATION
        # ============================================
        full_extraction = result_state.fields
        final_result = full_extraction
        wrapper_used = False
        # Only activate the wrapper when a user message was actually provided
        if user_message is not None and len(user_message) > 0:
            logger.info("=" * 70)
            logger.info(f"💬 USER MESSAGE DETECTED: '{user_message}'")
            logger.info("🎯 Activating Gemini wrapper to filter output...")
            logger.info(f"📦 Full extraction fields: {list(full_extraction.keys())}")
            logger.info("=" * 70)
            try:
                from backend.app.wrappers.gemini_output_filter import GeminiOutputFilter
                wrapper = GeminiOutputFilter()
                final_result = wrapper.filter_output(user_message, full_extraction)
                wrapper_used = True
                logger.info("=" * 70)
                logger.info("✅ WRAPPER SUCCESS!")
                logger.info(f"📤 Original fields: {list(full_extraction.keys())}")
                logger.info(f"🎯 Filtered fields: {list(final_result.keys())}")
                logger.info(f"📋 Filtered result: {json.dumps(final_result, indent=2)}")
                logger.info("=" * 70)
            except Exception as wrapper_error:
                logger.error("=" * 70)
                logger.error(f"❌ WRAPPER FAILED: {wrapper_error}")
                logger.error("=" * 70)
                import traceback
                logger.error(traceback.format_exc())
                logger.warning("📦 Falling back to full extraction")
                final_result = full_extraction
                wrapper_used = False
        else:
            logger.info("=" * 70)
            logger.info("ℹ️ No user message provided - returning full extraction")
            logger.info(f"📦 Full extraction fields: {list(full_extraction.keys())}")
            logger.info("=" * 70)
        # ============================================
        # Save results
        # ============================================
        if result_state.fields:
            # Determine the method used
            if 'use_gemini' in result_state.history:
                method = 'autonomous_agent_gemini'
            elif 'use_regex' in result_state.history:
                method = 'autonomous_agent_regex'
            else:
                method = 'autonomous_agent'
            if wrapper_used:
                method += '_with_wrapper'
            save_extraction(
                doc_id,
                result_state.raw_text or '',
                result_state.tables or [],
                result_state.entities or [],
                {
                    'method': method,
                    'attempts': result_state.attempts,
                    'actions': result_state.history,
                    'confidence': agent._calculate_overall_confidence(result_state),
                    'errors': result_state.errors,
                    'user_message': user_message,
                    'wrapper_used': wrapper_used,
                    'full_extraction_keys': list(full_extraction.keys()) if full_extraction else [],
                    'filtered_keys': list(final_result.keys()) if wrapper_used else None
                },
                None
            )
            # Save the (possibly filtered) result
            save_invoice_fields(
                doc_id,
                final_result,
                result_state.confidence_map or {}
            )
            # Call prediction
            logger.info("🔮 Calling payment prediction...")
            try:
                pred_response = httpx.post(PREDICT_ENDPOINT, json=final_result, timeout=30)
                if pred_response.status_code == 200:
                    pred_result = pred_response.json()
                    logger.info(f"✅ Prediction: {pred_result.get('predicted_days_to_clear')} days")
            except Exception as e:
                logger.error(f"⚠️ Prediction failed: {e}")
            # Check status
            from backend.app.agent.agent_orchestrator import AgentDecision
            if AgentDecision.HUMAN_REVIEW.value in result_state.history:
                update_job_status(job_id, 'needs_review')
                logger.info("👤 Agent requesting human review")
            else:
                update_job_status(job_id, 'completed')
                logger.info(f"✅ Agent completed with confidence: {agent._calculate_overall_confidence(result_state):.2f}")
        else:
            update_job_status(job_id, 'failed', 'Agent could not extract fields')
            logger.error("❌ Agent failed to extract any fields")
    except ImportError as e:
        logger.error(f"❌ Agent module not found: {e}")
        logger.info("⚠️ Falling back to legacy pipeline...")
        process_document_legacy(job_id, doc_id, file_path)
    except Exception as e:
        logger.error(f"❌ Agent failed: {e}")
        import traceback
        traceback.print_exc()
        update_job_status(job_id, 'failed', str(e))

def process_document(job_id: str, doc_id: str, file_path: Path, user_message: Optional[str] = None):
    """Main entry point - routes to the agent or legacy pipeline."""
    # Clean up user_message
    if user_message in [None, 'None', '', 'null', 'undefined']:
        user_message = None
    else:
        user_message = str(user_message).strip()
        if not user_message:
            user_message = None
    logger.info("=" * 70)
    logger.info(f"🔍 PROCESS_DOCUMENT - Cleaned user_message: '{user_message}'")
    logger.info(f"🔍 Type: {type(user_message)}")
    logger.info(f"🔍 Is None: {user_message is None}")
    logger.info("=" * 70)
    if USE_AGENT_MODE:
        logger.info("🤖 Using AUTONOMOUS AGENT mode")
        process_document_agent(job_id, doc_id, file_path, user_message=user_message)
    else:
        logger.info("📋 Using LEGACY pipeline mode")
        process_document_legacy(job_id, doc_id, file_path)

# ============================================
# API Endpoints
# ============================================
class IngestResponse(BaseModel):
    job_id: str
    doc_id: str
    filename: str
    status: str
    message: str

class JobStatusResponse(BaseModel):
    job_id: str
    doc_id: str
    filename: str
    status: str
    error_text: Optional[str] = None
    created_at: str
    updated_at: str
    extraction: Optional[Dict] = None
    invoice_fields: Optional[Dict] = None

class BatchIngestResponse(BaseModel):
    batch_id: str
    total_files: int
    jobs: List[Dict[str, str]]
    message: str

class BatchStatusResponse(BaseModel):
    batch_id: str
    total_files: int
    completed: int
    processing: int
    failed: int
    queued: int
    jobs: List[Dict[str, Any]]

@router.post("/ingest", response_model=IngestResponse)  # route assumed, mirroring the /ingest/batch endpoints below
async def ingest_document(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    message: str = Form(None)  # Form(None) so multipart clients can omit the field
):
    """Upload a single document for processing, with an optional filtering message."""
    # Clean the message parameter
    cleaned_message = None
    if message and message not in ['None', 'null', '', 'undefined']:
        cleaned_message = message.strip()
        if not cleaned_message:
            cleaned_message = None
    logger.info("=" * 70)
    logger.info(f"📨 API ENDPOINT - Raw message: '{message}'")
    logger.info(f"✨ Cleaned message: '{cleaned_message}'")
    logger.info(f"🔍 Message type: {type(cleaned_message)}")
    logger.info(f"❓ Is None: {cleaned_message is None}")
    logger.info("=" * 70)
    try:
        allowed_types = ['application/pdf', 'image/png', 'image/jpeg']
        if file.content_type not in allowed_types:
            raise HTTPException(400, f"Invalid file type: {file.content_type}")
        job_id = f"job_{uuid.uuid4().hex[:12]}"
        doc_id = f"doc_{uuid.uuid4().hex[:12]}"
        file_ext = file.filename.split('.')[-1] if '.' in file.filename else 'pdf'
        stored_filename = f"{doc_id}.{file_ext}"
        file_path = STORAGE_PATH / stored_filename
        content = await file.read()
        with open(file_path, 'wb') as f:
            f.write(content)
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO ingest_jobs (job_id, doc_id, filename, status)
                VALUES (?, ?, ?, 'queued')
            """, (job_id, doc_id, file.filename))
            cursor.execute("""
                INSERT INTO documents (doc_id, job_id, path, filename, content_type)
                VALUES (?, ?, ?, ?, ?)
            """, (doc_id, job_id, str(file_path), file.filename, file.content_type))
            conn.commit()
            conn.close()
        # Start processing with the cleaned message
        background_tasks.add_task(
            process_document,
            job_id,
            doc_id,
            file_path,
            user_message=cleaned_message
        )
        logger.info(f"🚀 Background task started with message: '{cleaned_message}'")
        mode = "autonomous agent"
        if cleaned_message:
            mode += " with intelligent filtering"
            logger.info(f"🎯 User wants: '{cleaned_message}'")
        return IngestResponse(
            job_id=job_id,
            doc_id=doc_id,
            filename=file.filename,
            status='queued',
            message=f'Document uploaded. Processing with {mode}.'
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Ingest endpoint error: {e}")
        import traceback
        logger.error(traceback.format_exc())
        raise HTTPException(500, str(e))
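
# Example call (assuming the inferred /api/ingest route above):
#   curl -F "file=@invoice1.jpg" -F "message=extract only total and date" \
#        http://localhost:7860/api/ingest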

@router.get("/ingest/{job_id}", response_model=JobStatusResponse)  # route assumed, mirroring /ingest/batch/{batch_id}
def get_ingest_status(job_id: str):
    """Get job status with agent decision history (if applicable)"""
    try:
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM ingest_jobs WHERE job_id = ?", (job_id,))
            job = cursor.fetchone()
            if not job:
                conn.close()
                raise HTTPException(404, "Job not found")
            job_data = dict(job)
            doc_id = job_data['doc_id']
            if job_data['status'] in ['completed', 'needs_review']:
                cursor.execute("SELECT * FROM extractions WHERE doc_id = ?", (doc_id,))
                extraction = cursor.fetchone()
                if extraction:
                    ext_dict = dict(extraction)
                    if ext_dict.get('raw_text'):
                        ext_dict['raw_text'] = ext_dict['raw_text'][:500] + "..."
                    job_data['extraction'] = ext_dict
                cursor.execute("SELECT * FROM invoice_fields WHERE doc_id = ?", (doc_id,))
                invoice = cursor.fetchone()
                if invoice:
                    inv_dict = dict(invoice)
                    if inv_dict.get('confidence_map'):
                        inv_dict['confidence_map'] = json.loads(inv_dict['confidence_map'])
                    job_data['invoice_fields'] = inv_dict
            conn.close()
        return JobStatusResponse(**job_data)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Job status error: {e}")
        raise HTTPException(500, str(e))

@router.post("/ingest/batch", response_model=BatchIngestResponse)  # route taken from the curl examples below
async def ingest_batch_documents(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...),
    message: str = Form(None)
):
    """
    Upload multiple documents for batch processing.

    Examples:
    1. Batch upload without filtering:
       curl -F "files=@invoice1.jpg" -F "files=@invoice2.pdf" -F "files=@invoice3.png" \
            http://localhost:7860/api/ingest/batch
    2. Batch upload with the same extraction rule for all files:
       curl -F "files=@invoice1.jpg" -F "files=@invoice2.jpg" \
            -F "message=extract only total and date" \
            http://localhost:7860/api/ingest/batch
    3. Maximum 50 files per batch.
    """
    # Validate batch size
    if len(files) > 50:
        raise HTTPException(400, "Maximum 50 files per batch")
    if len(files) == 0:
        raise HTTPException(400, "No files provided")
    # Clean the message
    cleaned_message = None
    if message and message not in ['None', 'null', '', 'undefined']:
        cleaned_message = message.strip()
        if not cleaned_message:
            cleaned_message = None
    batch_id = f"batch_{uuid.uuid4().hex[:12]}"
    jobs = []
    logger.info("=" * 70)
    logger.info(f"📦 BATCH UPLOAD - {len(files)} files")
    logger.info(f"📦 Batch ID: {batch_id}")
    logger.info(f"📦 Message: '{cleaned_message}'")
    logger.info("=" * 70)
    try:
        allowed_types = ['application/pdf', 'image/png', 'image/jpeg']
        for idx, file in enumerate(files):
            # Validate each file
            if file.content_type not in allowed_types:
                logger.warning(f"⚠️ Skipping {file.filename} - invalid type: {file.content_type}")
                continue
            # Create a job for this file
            job_id = f"job_{uuid.uuid4().hex[:12]}"
            doc_id = f"doc_{uuid.uuid4().hex[:12]}"
            file_ext = file.filename.split('.')[-1] if '.' in file.filename else 'pdf'
            stored_filename = f"{doc_id}.{file_ext}"
            file_path = STORAGE_PATH / stored_filename
            # Save the file
            content = await file.read()
            with open(file_path, 'wb') as f:
                f.write(content)
            # Save to database
            with FileLock(str(LOCK_PATH), timeout=10):
                conn = sqlite3.connect(str(DB_PATH))
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO ingest_jobs (job_id, doc_id, filename, status)
                    VALUES (?, ?, ?, 'queued')
                """, (job_id, doc_id, file.filename))
                cursor.execute("""
                    INSERT INTO documents (doc_id, job_id, path, filename, content_type)
                    VALUES (?, ?, ?, ?, ?)
                """, (doc_id, job_id, str(file_path), file.filename, file.content_type))
                conn.commit()
                conn.close()
            # Queue processing
            background_tasks.add_task(
                process_document,
                job_id,
                doc_id,
                file_path,
                user_message=cleaned_message
            )
            jobs.append({
                'job_id': job_id,
                'doc_id': doc_id,
                'filename': file.filename,
                'status': 'queued'
            })
            logger.info(f"✅ [{idx+1}/{len(files)}] Queued: {file.filename}")
        if not jobs:
            raise HTTPException(400, "No valid files to process")
        # Save batch metadata
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            cursor = conn.cursor()
            # Defensive: create the batch tables if module init did not
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS batch_jobs (
                    batch_id TEXT PRIMARY KEY,
                    total_files INTEGER,
                    message TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            cursor.execute("""
                INSERT INTO batch_jobs (batch_id, total_files, message)
                VALUES (?, ?, ?)
            """, (batch_id, len(jobs), cleaned_message))
            # Link jobs to the batch
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS batch_job_mapping (
                    batch_id TEXT,
                    job_id TEXT,
                    FOREIGN KEY (job_id) REFERENCES ingest_jobs(job_id)
                )
            """)
            for job in jobs:
                cursor.execute("""
                    INSERT INTO batch_job_mapping (batch_id, job_id)
                    VALUES (?, ?)
                """, (batch_id, job['job_id']))
            conn.commit()
            conn.close()
        mode = "autonomous agent"
        if cleaned_message:
            mode += " with intelligent filtering"
        logger.info(f"🚀 Batch {batch_id} processing started with {len(jobs)} files")
        return BatchIngestResponse(
            batch_id=batch_id,
            total_files=len(jobs),
            jobs=jobs,
            message=f'Batch of {len(jobs)} documents uploaded. Processing with {mode}.'
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Batch ingest error: {e}")
        import traceback
        logger.error(traceback.format_exc())
        raise HTTPException(500, str(e))

@router.get("/ingest/batch/{batch_id}", response_model=BatchStatusResponse)  # route taken from the curl example below
def get_batch_status(batch_id: str):
    """
    Get the status of all jobs in a batch.

    Example:
        curl http://localhost:7860/api/ingest/batch/batch_abc123
    """
    try:
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()
            # Get batch info
            cursor.execute("SELECT * FROM batch_jobs WHERE batch_id = ?", (batch_id,))
            batch = cursor.fetchone()
            if not batch:
                conn.close()
                raise HTTPException(404, "Batch not found")
            # Get all jobs in the batch
            cursor.execute("""
                SELECT j.* FROM ingest_jobs j
                JOIN batch_job_mapping bm ON j.job_id = bm.job_id
                WHERE bm.batch_id = ?
            """, (batch_id,))
            jobs = cursor.fetchall()
            conn.close()
        # Count statuses
        status_counts = {
            'completed': 0,
            'processing': 0,
            'failed': 0,
            'queued': 0,
            'needs_review': 0
        }
        jobs_list = []
        for job in jobs:
            job_dict = dict(job)
            status = job_dict['status']
            status_counts[status] = status_counts.get(status, 0) + 1
            jobs_list.append({
                'job_id': job_dict['job_id'],
                'doc_id': job_dict['doc_id'],
                'filename': job_dict['filename'],
                'status': status,
                'error_text': job_dict.get('error_text'),
                'created_at': job_dict['created_at'],
                'updated_at': job_dict['updated_at']
            })
        return BatchStatusResponse(
            batch_id=batch_id,
            total_files=len(jobs),
            completed=status_counts['completed'],
            processing=status_counts['processing'],
            failed=status_counts['failed'],
            queued=status_counts['queued'],
            jobs=jobs_list
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Batch status error: {e}")
        raise HTTPException(500, str(e))

@router.get("/ingest/batch/{batch_id}/download")  # route taken from the curl example below
def download_batch_results(batch_id: str):
    """
    Download all extracted data from a batch as CSV.

    Example:
        curl http://localhost:7860/api/ingest/batch/batch_abc123/download -o results.csv
    """
    try:
        from io import StringIO
        from fastapi.responses import StreamingResponse
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()
            # Get all completed jobs in the batch
            cursor.execute("""
                SELECT j.*, f.* FROM ingest_jobs j
                JOIN batch_job_mapping bm ON j.job_id = bm.job_id
                LEFT JOIN invoice_fields f ON j.doc_id = f.doc_id
                WHERE bm.batch_id = ? AND j.status = 'completed'
            """, (batch_id,))
            results = cursor.fetchall()
            conn.close()
        if not results:
            raise HTTPException(404, "No completed jobs found in batch")
        # Build the CSV in memory (csv is imported at module level)
        output = StringIO()
        writer = csv.writer(output)
        # Header row
        writer.writerow([
            'filename', 'doc_id', 'customer', 'date', 'amount',
            'payment_terms', 'business_code', 'status'
        ])
        # Data rows
        for row in results:
            writer.writerow([
                row['filename'],
                row['doc_id'],
                row['cust_number'] or 'N/A',
                row['posting_date'] or 'N/A',
                row['total_open_amount'] or 0.0,
                row['cust_payment_terms'] or 'N/A',
                row['business_code'] or 'N/A',
                row['status']
            ])
        output.seek(0)
        return StreamingResponse(
            iter([output.getvalue()]),
            media_type="text/csv",
            headers={
                "Content-Disposition": f"attachment; filename=batch_{batch_id}_results.csv"
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Batch download error: {e}")
        raise HTTPException(500, str(e))