# backend/app/api/ingest.py
# Commit aef305f ("Update backend/app/api/ingest.py", verified)
import os
import uuid
import json
import sqlite3
import logging
import csv
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, List, Any
from fastapi import APIRouter, UploadFile, File, HTTPException, BackgroundTasks, Form
from pydantic import BaseModel
from filelock import FileLock
import httpx
import re
import sys
# Root logger setup: force=True overrides any handlers installed earlier
# (e.g. by uvicorn) so this format wins; logs go to stdout so container /
# HF Space log capture picks them up.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True
)
logger = logging.getLogger(__name__)

# Filesystem layout: <repo_root>/data holds uploaded documents, the SQLite
# database, and its companion lock file.
BASE_DIR = Path(__file__).parent.parent.parent.parent
STORAGE_PATH = Path(os.getenv('STORAGE_PATH', str(BASE_DIR / "data" / "docs")))
DB_PATH = Path(os.getenv('DB_PATH', str(BASE_DIR / "data" / "invoices.db")))
LOCK_PATH = BASE_DIR / "data" / "invoices.db.lock"
# Env-overridable for parity with TEXT_EXTRACTOR_URL / TABLE_EXTRACTOR_URL /
# NER_URL below; defaults to the local prediction service.
PREDICT_ENDPOINT = os.getenv('PREDICT_ENDPOINT', 'http://localhost:7860/predict')
STORAGE_PATH.mkdir(parents=True, exist_ok=True)
router = APIRouter(prefix="/api", tags=["ingest"])
def _init_db_tables():
    """Create the SQLite schema if it does not exist yet.

    Runs at module import because HF Spaces start from a fresh filesystem.
    Serialized via the shared file lock so concurrent workers do not race on
    CREATE TABLE.  Never raises: any failure is logged and swallowed so the
    module can still be imported and report errors at request time.
    """
    try:
        logger.info("🔍 Checking if database tables exist...")
        DB_PATH.parent.mkdir(parents=True, exist_ok=True)
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            try:
                cursor = conn.cursor()
                # Fast path: if the first table exists, assume the whole
                # schema was created by a previous run.
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='ingest_jobs'")
                if cursor.fetchone():
                    logger.info("✅ Database tables already exist")
                    return
                logger.warning("⚠️ Database tables not found, creating...")
                # Create all tables
                tables_sql = [
                    """CREATE TABLE IF NOT EXISTS ingest_jobs (
                    job_id TEXT PRIMARY KEY,
                    doc_id TEXT,
                    filename TEXT NOT NULL,
                    status TEXT NOT NULL DEFAULT 'queued',
                    error_text TEXT,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    updated_at TEXT DEFAULT CURRENT_TIMESTAMP
                    )""",
                    """CREATE TABLE IF NOT EXISTS documents (
                    doc_id TEXT PRIMARY KEY,
                    job_id TEXT NOT NULL,
                    path TEXT NOT NULL,
                    filename TEXT NOT NULL,
                    content_type TEXT NOT NULL,
                    uploaded_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (job_id) REFERENCES ingest_jobs(job_id)
                    )""",
                    """CREATE TABLE IF NOT EXISTS extractions (
                    doc_id TEXT PRIMARY KEY,
                    raw_text TEXT,
                    tables_json TEXT,
                    entities_json TEXT,
                    classification_json TEXT,
                    summary_text TEXT,
                    extracted_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (doc_id) REFERENCES documents(doc_id)
                    )""",
                    """CREATE TABLE IF NOT EXISTS invoice_fields (
                    invoice_id INTEGER PRIMARY KEY AUTOINCREMENT,
                    doc_id TEXT NOT NULL,
                    cust_number TEXT,
                    posting_date TEXT,
                    total_open_amount REAL,
                    business_code TEXT,
                    cust_payment_terms TEXT,
                    confidence_map TEXT,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (doc_id) REFERENCES documents(doc_id)
                    )""",
                    """CREATE TABLE IF NOT EXISTS batch_jobs (
                    batch_id TEXT PRIMARY KEY,
                    total_files INTEGER,
                    message TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )""",
                    """CREATE TABLE IF NOT EXISTS batch_job_mapping (
                    batch_id TEXT,
                    job_id TEXT,
                    FOREIGN KEY (batch_id) REFERENCES batch_jobs(batch_id),
                    FOREIGN KEY (job_id) REFERENCES ingest_jobs(job_id)
                    )""",
                    # ML feature tables
                    """CREATE TABLE IF NOT EXISTS customer_aggregates (
                    cust_number TEXT PRIMARY KEY,
                    cust_avg_days REAL,
                    cust_median_days REAL,
                    cust_invoice_count INTEGER,
                    last_updated TEXT DEFAULT CURRENT_TIMESTAMP
                    )""",
                    """CREATE TABLE IF NOT EXISTS payment_terms_aggregates (
                    cust_payment_terms TEXT PRIMARY KEY,
                    terms_avg_days REAL,
                    terms_median_days REAL,
                    terms_invoice_count INTEGER,
                    last_updated TEXT DEFAULT CURRENT_TIMESTAMP
                    )""",
                    """CREATE TABLE IF NOT EXISTS business_code_aggregates (
                    business_code TEXT PRIMARY KEY,
                    bc_avg_days REAL,
                    bc_median_days REAL,
                    bc_invoice_count INTEGER,
                    last_updated TEXT DEFAULT CURRENT_TIMESTAMP
                    )""",
                    """CREATE TABLE IF NOT EXISTS predictions_log (
                    prediction_id INTEGER PRIMARY KEY AUTOINCREMENT,
                    invoice_id INTEGER,
                    cust_number TEXT,
                    posting_date TEXT,
                    total_open_amount REAL,
                    business_code TEXT,
                    cust_payment_terms TEXT,
                    predicted_days_to_clear REAL,
                    predicted_clear_date TEXT,
                    model_version TEXT,
                    features_json TEXT,
                    predicted_at TEXT DEFAULT CURRENT_TIMESTAMP
                    )"""
                ]
                # Execute all CREATE TABLE statements
                for sql in tables_sql:
                    cursor.execute(sql)
                # Create indexes
                indexes_sql = [
                    "CREATE INDEX IF NOT EXISTS idx_ingest_jobs_status ON ingest_jobs(status)",
                    "CREATE INDEX IF NOT EXISTS idx_ingest_jobs_created ON ingest_jobs(created_at DESC)",
                    "CREATE INDEX IF NOT EXISTS idx_documents_job_id ON documents(job_id)",
                    "CREATE INDEX IF NOT EXISTS idx_invoice_fields_doc_id ON invoice_fields(doc_id)",
                    "CREATE INDEX IF NOT EXISTS idx_batch_mapping_batch ON batch_job_mapping(batch_id)",
                    "CREATE INDEX IF NOT EXISTS idx_predictions_log_cust ON predictions_log(cust_number)"
                ]
                for sql in indexes_sql:
                    cursor.execute(sql)
                conn.commit()
                logger.info("✅ Database tables created successfully!")
            finally:
                # BUGFIX: always release the connection, even when a CREATE
                # statement raises (previously the connection leaked).
                conn.close()
    except Exception as e:
        # Deliberately swallowed: module import must never fail here.
        logger.error(f"❌ Failed to create tables: {e}")
        import traceback
        logger.error(traceback.format_exc())
# Run on module import: the HF Space container must have a working schema
# before the first request arrives.
logger.info("🔍 Initializing database on module load...")
try:
    _init_db_tables()
    logger.info("✅ Database initialization complete")
except Exception as e:
    # _init_db_tables already swallows its own errors; this guard is
    # belt-and-braces so a failure can never prevent the module import.
    logger.error(f"❌ Database initialization failed: {e}")
    logger.warning("⚠️ Application may not work correctly!")
# ============================================
# LOCAL OCR FALLBACK (EasyOCR + Tesseract)
# ============================================
def extract_text_with_easyocr(file_path: Path) -> tuple:
    """Run EasyOCR over a document image; return (success, text, error).

    EasyOCR is the preferred offline fallback: no network needed, 80+
    languages, GPU or CPU, and generally better invoice accuracy than
    Tesseract.  Model weights download automatically on first use.
    """
    try:
        import easyocr

        logger.info("🔧 Using EasyOCR (best free OCR)...")
        # CPU keeps the deployment portable; flip gpu=True on CUDA hosts.
        ocr = easyocr.Reader(['en'], gpu=False)
        fragments = ocr.readtext(str(file_path), detail=0, paragraph=True)
        text = '\n'.join(fragments)
        if not (text and len(text.strip()) >= 10):
            return False, None, "EasyOCR produced no usable text"
        logger.info(f"✅ EasyOCR extracted {len(text)} characters")
        return True, text, None
    except ImportError:
        logger.warning("⚠️ easyocr not installed. Install with: pip install easyocr")
        return False, None, "easyocr not available"
    except Exception as e:
        logger.error(f"❌ EasyOCR failed: {e}")
        return False, None, str(e)
def extract_text_with_tesseract(file_path: Path) -> tuple:
    """Run Tesseract OCR over a document image; return (success, text, error).

    Secondary fallback after EasyOCR: faster, but usually less accurate.
    """
    try:
        import pytesseract
        from PIL import Image

        logger.info("🔧 Using Tesseract OCR as secondary fallback...")
        page = Image.open(file_path)
        text = pytesseract.image_to_string(page)
        if not (text and len(text.strip()) >= 10):
            return False, None, "Tesseract produced no usable text"
        logger.info(f"✅ Tesseract extracted {len(text)} characters")
        return True, text, None
    except ImportError:
        logger.warning("⚠️ pytesseract not installed. Install with: pip install pytesseract pillow")
        return False, None, "pytesseract not available"
    except Exception as e:
        logger.error(f"❌ Tesseract failed: {e}")
        return False, None, str(e)
def extract_text_with_local_ocr(file_path: Path) -> tuple:
    """Cascade through the local OCR engines; return (success, text, error).

    Order of preference:
      1. EasyOCR   - best accuracy on invoices
      2. Tesseract - faster, less accurate
    Returns a failure tuple only when every engine has failed.
    """
    banner = "=" * 70
    logger.info(banner)
    logger.info("🔄 HF extraction failed - trying local OCR fallbacks...")
    logger.info(banner)

    engines = (
        ("EasyOCR", extract_text_with_easyocr),
        ("Tesseract", extract_text_with_tesseract),
    )
    for name, engine in engines:
        ok, text, err = engine(file_path)
        if ok:
            logger.info(f"✅ {name} succeeded!")
            return True, text, None
        logger.warning(f"⚠️ {name} failed: {err}")

    # All local OCR failed
    logger.error("❌ All local OCR methods failed")
    return False, None, "All local OCR methods failed"
# ============================================
# STEP 1: HF Agent Text Extraction (UPDATED)
# ============================================
def get_agent_headers():
    """Build the Authorization header for HF agent calls.

    Checks HF_TOKEN, HUGGINGFACE_API_TOKEN and AGENT_BEARER_TOKEN in that
    order; returns an empty dict when no (non-empty) token is configured.
    """
    for var in ('HF_TOKEN', 'HUGGINGFACE_API_TOKEN', 'AGENT_BEARER_TOKEN'):
        token = os.getenv(var)
        if token:
            return {'Authorization': f'Bearer {token}'}
    return {}
def get_mime_type(file_path: Path) -> str:
    """Map the file extension to a MIME type; octet-stream for unknowns."""
    known_types = {
        '.pdf': 'application/pdf',
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
    }
    return known_types.get(file_path.suffix.lower(), 'application/octet-stream')
def call_text_extractor(file_path: Path, max_retries=3):
    """Extract raw text via the HF text-extraction Space.

    Retries with a progressively longer timeout (base, base+60, base+120)
    and falls back to local OCR (EasyOCR then Tesseract) once every HF
    attempt has failed — including 200 responses that carry no usable text.

    Returns (success, text, error).
    """
    url = os.getenv('TEXT_EXTRACTOR_URL', 'https://point9-extract-text-and-table.hf.space/api/text')
    base_timeout = int(os.getenv('AGENT_TIMEOUT_SECONDS', '120'))
    for attempt in range(max_retries):
        # Progressive timeout: 120s, 180s, 240s
        timeout = base_timeout + (60 * attempt)
        try:
            logger.info(f"📄 Extracting text from {file_path.name} (attempt {attempt + 1}/{max_retries}, timeout={timeout}s)...")
            filename = file_path.name
            mime_type = get_mime_type(file_path)
            with open(file_path, 'rb') as f:
                files = {'file': (filename, f, mime_type)}
                data = {
                    'filename': filename,
                    'start_page': 1,
                    'end_page': 1
                }
                headers = get_agent_headers()
                response = httpx.post(url, files=files, data=data, headers=headers, timeout=timeout)
            if response.status_code == 200:
                result = response.json()
                # The Space has used several response keys over time.
                text = result.get('result') or result.get('text') or result.get('extracted_text') or ''
                if text and len(text.strip()) >= 10:
                    logger.info(f"✅ Extracted {len(text)} characters")
                    return True, text, None
                # BUGFIX: a final 200-but-empty response previously returned
                # failure without trying local OCR; now every exhausted path
                # falls through to the fallback below.
                logger.warning("⚠️ No text extracted from response")
            else:
                logger.warning(f"⚠️ HTTP {response.status_code}: {response.text[:200]}")
        except httpx.TimeoutException:
            logger.warning(f"⚠️ Timeout after {timeout}s on attempt {attempt + 1}")
            if attempt < max_retries - 1:
                logger.info("🔄 Retrying with longer timeout...")
        except Exception as e:
            logger.error(f"❌ Error on attempt {attempt + 1}: {e}")
            if attempt < max_retries - 1:
                logger.info("🔄 Retrying...")
    # All retries failed - try local OCR fallback
    logger.warning(f"⚠️ All {max_retries} HF extraction attempts failed, trying local OCR fallback...")
    return extract_text_with_local_ocr(file_path)
def call_table_extractor(file_path: Path, max_retries=2):
    """Extract tables via the HF table-extraction Space.

    Tables are nice-to-have, so this uses fewer retries than the text
    extractor and on failure returns (False, [], message) instead of
    aborting the pipeline.
    """
    url = os.getenv('TABLE_EXTRACTOR_URL', 'https://point9-extract-text-and-table.hf.space/api/tables')
    base_timeout = int(os.getenv('AGENT_TIMEOUT_SECONDS', '120'))
    for attempt in range(max_retries):
        timeout = base_timeout + (60 * attempt)
        try:
            logger.info(f"📊 Extracting tables from {file_path.name} (attempt {attempt + 1}/{max_retries})...")
            mime_type = get_mime_type(file_path)
            form_fields = {
                'filename': file_path.name,
                'start_page': 1,
                'end_page': 1
            }
            with open(file_path, 'rb') as fh:
                upload = {'file': (file_path.name, fh, mime_type)}
                response = httpx.post(
                    url,
                    files=upload,
                    data=form_fields,
                    headers=get_agent_headers(),
                    timeout=timeout,
                )
            if response.status_code == 200:
                body = response.json()
                tables = body.get('result') or body.get('tables') or []
                logger.info(f"✅ Extracted {len(tables)} tables")
                return True, tables, None
            logger.warning(f"⚠️ HTTP {response.status_code}")
        except httpx.TimeoutException:
            logger.warning(f"⚠️ Table extraction timeout on attempt {attempt + 1}")
        except Exception as e:
            logger.warning(f"⚠️ Table extraction error: {e}")
    # Non-critical - return empty list
    logger.info("ℹ️ Table extraction failed, continuing without tables")
    return False, [], "Table extraction failed (non-critical)"
# ============================================
# STEP 2: HF NER (Named Entity Recognition)
# ============================================
def call_ner(text: str, file_path: Path = None, max_retries=2) -> tuple:
    """Run the HF NER agent; return (success, entities, entity_map, error).

    Sends the document file plus a text excerpt when the file exists,
    otherwise just the text as form data.  entity_map groups surface text
    by entity type (PERSON/ORG/DATE/MONEY/CARDINAL) for downstream
    consumers.  Non-critical: after all retries the pipeline continues
    without entities.
    """
    url = os.getenv('NER_URL', 'https://point9-ner.hf.space/api/ner')
    base_timeout = int(os.getenv('AGENT_TIMEOUT_SECONDS', '120'))
    for attempt in range(max_retries):
        timeout = base_timeout + (30 * attempt)
        try:
            logger.info(f"🔍 Running NER to find entities (attempt {attempt + 1}/{max_retries})...")
            headers = get_agent_headers()
            # NER expects multipart/form-data with file OR text
            if file_path and file_path.exists():
                # Send file
                filename = file_path.name
                mime_type = get_mime_type(file_path)
                with open(file_path, 'rb') as f:
                    files = {'file': (filename, f, mime_type)}
                    data = {
                        'text': text[:5000],
                        'filename': filename,
                        'start_page': 1,
                        'end_page': 1
                    }
                    response = httpx.post(url, files=files, data=data, headers=headers, timeout=timeout)
            else:
                # Send just text as form data
                data = {
                    'text': text[:5000],
                    'filename': 'document.txt',
                    'start_page': 1,
                    'end_page': 1
                }
                response = httpx.post(url, data=data, headers=headers, timeout=timeout)
            if response.status_code == 200:
                result = response.json()
                # Some deployments double-encode the payload as a JSON string.
                if isinstance(result, str):
                    try:
                        result = json.loads(result)
                    # BUGFIX: was a bare `except:` that swallowed everything
                    # including KeyboardInterrupt/SystemExit.
                    except json.JSONDecodeError:
                        logger.warning(f"⚠️ NER returned unparseable string: {result[:100]}")
                        if attempt < max_retries - 1:
                            continue
                        return False, [], {}, "Invalid response format"
                # Extract entities
                entities = result.get('entities') or result.get('result') or []
                # Handle case where entities might also be a string
                if isinstance(entities, str):
                    try:
                        entities = json.loads(entities)
                    except json.JSONDecodeError:  # was a bare except
                        entities = []
                logger.info(f"✅ Found {len(entities)} entities")
                # Group entities by type
                entity_map = {
                    'PERSON': [],
                    'ORG': [],
                    'DATE': [],
                    'MONEY': [],
                    'CARDINAL': []
                }
                for entity in entities:
                    if not isinstance(entity, dict):
                        continue
                    # Responses use entity_type/text or label/word keys.
                    ent_type = entity.get('entity_type') or entity.get('label')
                    ent_text = entity.get('text') or entity.get('word')
                    if ent_type in entity_map and ent_text:
                        entity_map[ent_type].append(ent_text)
                logger.info(f"📋 Entity summary: PERSON={len(entity_map['PERSON'])}, ORG={len(entity_map['ORG'])}, DATE={len(entity_map['DATE'])}, MONEY={len(entity_map['MONEY'])}")
                return True, entities, entity_map, None
            logger.warning(f"⚠️ NER HTTP {response.status_code}")
        except httpx.TimeoutException:
            logger.warning(f"⚠️ NER timeout on attempt {attempt + 1}")
        except Exception as e:
            logger.error(f"❌ NER error on attempt {attempt + 1}: {e}")
    # NER failed - return empty (non-critical)
    logger.warning("⚠️ NER failed after retries, continuing without entities")
    return False, [], {}, "NER failed (non-critical)"
# ============================================
# STEP 3: Gemini Intelligent Mapping
# ============================================
def map_with_gemini(text: str, entities: List, entity_map: Dict, tables: List):
    """Ask Gemini to map extracted text/entities/tables to invoice fields.

    Returns (success, result_dict, error).  On success result_dict carries
    customer_name, invoice_number, date, total_amount, payment_terms and a
    free-text "reasoning" string, per the JSON schema in the prompt.  Fails
    soft — (False, None, message) — when the API key is absent, the SDK
    raises, or the response is not valid JSON.
    """
    try:
        import google.generativeai as genai
        api_key = os.getenv('GEMINI_API_KEY')
        if not api_key:
            logger.warning("⚠️ No Gemini API key configured")
            return False, None, "No Gemini API key"
        logger.info("🧠 Using Gemini for intelligent field mapping...")
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('models/gemini-2.5-flash')
        # Build context for Gemini: truncated text plus NER output and at
        # most two tables to keep the prompt small.
        context = f"""
EXTRACTED TEXT:
{text[:3000]}
NAMED ENTITIES FOUND:
- Organizations: {entity_map.get('ORG', [])}
- People: {entity_map.get('PERSON', [])}
- Dates: {entity_map.get('DATE', [])}
- Money amounts: {entity_map.get('MONEY', [])}
- Numbers: {entity_map.get('CARDINAL', [])}
TABLES:
{json.dumps(tables[:2], indent=2) if tables else 'None'}
"""
        prompt = f"""You are an expert at analyzing invoice data. Given the extracted text and entities below, map them to invoice fields.
{context}
Analyze the above data and return ONLY a valid JSON object with these exact fields:
{{
"customer_name": "the client/customer company name (check ORG entities first)",
"invoice_number": "the invoice number (check CARDINAL entities)",
"date": "invoice date in YYYY-MM-DD format (check DATE entities)",
"total_amount": numeric total amount only (check MONEY entities, no currency symbol),
"payment_terms": "payment terms like NET30, NET60, or NAH4 if not found",
"reasoning": "brief explanation of how you identified each field"
}}
Rules:
1. Prefer entities over raw text when available
2. Customer name is usually the first ORG after "Bill To" or "Client"
3. Total amount is usually the largest MONEY value
4. Date should be in YYYY-MM-DD format
5. If uncertain, use these defaults: customer_name="UNKNOWN", date="2024-01-01", total_amount=0.0, payment_terms="NAH4"
Return ONLY the JSON object, no markdown, no explanation outside the JSON."""
        response = model.generate_content(prompt)
        text_response = response.text.strip()
        # Remove markdown if present (Gemini often wraps output in ```json fences).
        text_response = text_response.replace('```json', '').replace('```', '').strip()
        result = json.loads(text_response)
        logger.info(f"✅ Gemini mapped: Customer={result.get('customer_name')}, Amount=${result.get('total_amount')}")
        logger.info(f"💡 Reasoning: {result.get('reasoning', 'N/A')[:100]}")
        return True, result, None
    except json.JSONDecodeError as e:
        # text_response is always bound here: json.loads is the only raiser
        # of JSONDecodeError in the try block.
        logger.error(f"❌ Gemini returned invalid JSON: {e}")
        logger.error(f"Response: {text_response[:500]}")
        return False, None, f"Invalid JSON: {e}"
    except Exception as e:
        logger.error(f"❌ Gemini mapping failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return False, None, str(e)
# ============================================
# Fallback: Regex Mapping
# ============================================
def map_with_regex(text: str, entities: List) -> tuple:
    """Fallback field mapper: NER entities first, then raw-text regexes.

    Returns (fields, confidence): fields holds cust_number, posting_date,
    total_open_amount, cust_payment_terms and business_code; confidence
    maps each field name to a heuristic score in [0, 1].  Always produces
    a complete fields dict, substituting low-confidence defaults when
    nothing can be found.
    """
    logger.info("🔤 Using regex fallback for field mapping...")
    fields = {}
    confidence = {}

    def _entity_texts(label):
        # Surface text of every entity with the given label; NER responses
        # use either entity_type/text or label/word key pairs.
        return [e.get('text') or e.get('word') for e in entities
                if (e.get('entity_type') or e.get('label')) == label]

    # CUSTOMER NAME - prefer ORG entities over regexes
    org_entities = _entity_texts('ORG')
    if org_entities:
        fields['cust_number'] = org_entities[0][:20]
        confidence['cust_number'] = 0.8
    else:
        # Regex fallback
        client_patterns = [
            r'(?:Client|Bill\s+To|Customer)[:\s]+(.*?)(?:\n|Tax|IBAN)',
            r'(?:customer|client)[\s:]+([A-Za-z][A-Za-z\s,&-]+?)(?:\n|$)',
        ]
        for pattern in client_patterns:
            match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
            if match:
                client = match.group(1).strip()
                # Keep the first word longer than 2 chars as a pseudo id.
                words = [w.strip() for w in client.replace(',', ' ').split() if len(w.strip()) > 2]
                if words:
                    fields['cust_number'] = words[0][:20]
                    confidence['cust_number'] = 0.6
                    break
    if 'cust_number' not in fields:
        fields['cust_number'] = 'UNKNOWN'
        confidence['cust_number'] = 0.1

    # DATE - try DATE entities first
    date_entities = _entity_texts('DATE')
    if date_entities:
        date_str = date_entities[0]
        for fmt in ['%m/%d/%Y', '%d/%m/%Y', '%Y-%m-%d', '%m-%d-%Y']:
            try:
                dt = datetime.strptime(date_str, fmt)
                fields['posting_date'] = dt.strftime('%Y-%m-%d')
                confidence['posting_date'] = 0.8
                break
            # BUGFIX: was a bare `except:`; strptime failures are
            # ValueError (TypeError if the entity text isn't a string).
            except (ValueError, TypeError):
                continue
    if 'posting_date' not in fields:
        date_patterns = [
            r'(?:Date\s+of\s+issue|Invoice\s+Date|Date)[:\s]+(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
        ]
        for pattern in date_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                date_str = match.group(1)
                for fmt in ['%m/%d/%Y', '%d/%m/%Y']:
                    try:
                        dt = datetime.strptime(date_str, fmt)
                        fields['posting_date'] = dt.strftime('%Y-%m-%d')
                        confidence['posting_date'] = 0.7
                        break
                    except ValueError:
                        continue
                if 'posting_date' in fields:
                    break
    if 'posting_date' not in fields:
        # Last resort: today's date with near-zero confidence.
        fields['posting_date'] = datetime.now().strftime('%Y-%m-%d')
        confidence['posting_date'] = 0.1

    # AMOUNT - try MONEY entities first
    money_entities = _entity_texts('MONEY')
    if money_entities:
        amounts = []
        for money_str in money_entities:
            try:
                # Remove currency symbols and parse
                amt = float(re.sub(r'[^\d.]', '', money_str))
                if amt > 10:
                    amounts.append(amt)
            except (ValueError, TypeError):  # was a bare except
                pass
        if amounts:
            fields['total_open_amount'] = max(amounts)
            confidence['total_open_amount'] = 0.8
            logger.info(f"✅ Found amount from MONEY entity: ${fields['total_open_amount']}")
    if 'total_open_amount' not in fields:
        # Regex fallback: dollar amounts like $1,234.56; the largest wins.
        pattern = r'\$\s*([0-9]{1,3}(?:,?[0-9]{3})*\.[0-9]{2})'
        amounts = []
        for match in re.finditer(pattern, text):
            try:
                amt = float(match.group(1).replace(',', ''))
                if amt > 50:
                    amounts.append(amt)
            except ValueError:  # was a bare except
                pass
        if amounts:
            fields['total_open_amount'] = max(amounts)
            confidence['total_open_amount'] = 0.6
        else:
            fields['total_open_amount'] = 0.0
            confidence['total_open_amount'] = 0.0
            logger.warning("⚠️ No amount found!")

    # PAYMENT TERMS like NET30 / N30 / NAH4 (project default when absent).
    terms_match = re.search(r'(NET\s?\d{1,2}|N\d{2}|NAH\d)', text, re.IGNORECASE)
    fields['cust_payment_terms'] = terms_match.group(1).upper() if terms_match else 'NAH4'
    confidence['cust_payment_terms'] = 0.7 if terms_match else 0.2

    # BUSINESS CODE has no reliable textual signal; use the house default.
    fields['business_code'] = 'U001'
    confidence['business_code'] = 0.2
    return fields, confidence
# ============================================
# Database Functions
# ============================================
def update_job_status(job_id: str, status: str, error_text: str = None):
    """Set a job's status (and optional error text), bumping updated_at.

    Serialized through the shared file lock because background tasks and
    request handlers write to the same SQLite file.
    """
    with FileLock(str(LOCK_PATH), timeout=10):
        conn = sqlite3.connect(str(DB_PATH))
        try:
            conn.execute("""
                UPDATE ingest_jobs
                SET status = ?, error_text = ?, updated_at = CURRENT_TIMESTAMP
                WHERE job_id = ?
            """, (status, error_text, job_id))
            conn.commit()
        finally:
            # BUGFIX: close even when the UPDATE raises, so the DB handle
            # is not leaked while the file lock has already been released.
            conn.close()
def save_extraction(doc_id: str, raw_text: str, tables: list, entities: list, classification: dict, summary: str = None):
    """Upsert the extraction results row for a document.

    tables/entities/classification are stored as JSON strings (NULL when
    empty); INSERT OR REPLACE keeps one row per doc_id.
    """
    with FileLock(str(LOCK_PATH), timeout=10):
        conn = sqlite3.connect(str(DB_PATH))
        try:
            conn.execute("""
                INSERT OR REPLACE INTO extractions (
                    doc_id, raw_text, tables_json, entities_json,
                    classification_json, summary_text
                ) VALUES (?, ?, ?, ?, ?, ?)
            """, (
                doc_id,
                raw_text,
                json.dumps(tables) if tables else None,
                json.dumps(entities) if entities else None,
                json.dumps(classification) if classification else None,
                summary
            ))
            conn.commit()
        finally:
            # BUGFIX: close in finally so a failed INSERT cannot leak the
            # connection.
            conn.close()
def save_invoice_fields(doc_id: str, fields: Dict, confidence_map: Dict):
    """Append an invoice_fields row for a document.

    Plain INSERT (not upsert): each processing run adds a new row keyed by
    the AUTOINCREMENT invoice_id.  confidence_map is stored as JSON.
    """
    with FileLock(str(LOCK_PATH), timeout=10):
        conn = sqlite3.connect(str(DB_PATH))
        try:
            conn.execute("""
                INSERT INTO invoice_fields (
                    doc_id, cust_number, posting_date, total_open_amount,
                    business_code, cust_payment_terms, confidence_map
                ) VALUES (?, ?, ?, ?, ?, ?, ?)
            """, (
                doc_id,
                fields.get('cust_number'),
                fields.get('posting_date'),
                fields.get('total_open_amount'),
                fields.get('business_code'),
                fields.get('cust_payment_terms'),
                json.dumps(confidence_map)
            ))
            conn.commit()
        finally:
            # BUGFIX: close in finally so a failed INSERT cannot leak the
            # connection.
            conn.close()
# ============================================
# AGENT MODE FLAG (Environment Variable)
# ============================================
# Defaults to the autonomous-agent pipeline; set USE_AGENT_MODE=false to
# force the legacy HF -> NER -> Gemini pipeline in process_document().
USE_AGENT_MODE = os.getenv('USE_AGENT_MODE', 'true').lower() == 'true'
# ============================================
# Main Processing Pipeline
# ============================================
def process_document_legacy(job_id: str, doc_id: str, file_path: Path):
    """Original three-stage pipeline.

    1. HF text extraction (+ optional table extraction)
    2. HF NER for entities
    3. Gemini field mapping, with regex mapping as fallback

    Persists the extraction and invoice fields, then marks the job
    completed/failed.  Never raises: failures are recorded on the job row.
    """
    logger.info("=" * 70)
    logger.info(f"🚀 Starting LEGACY pipeline for {file_path.name}")
    logger.info("=" * 70)
    try:
        update_job_status(job_id, 'processing')
        # STEP 1: Extract text with HF agents
        logger.info("STEP 1: HF TEXT + TABLE EXTRACTION")
        logger.info("-" * 70)
        success, raw_text, error = call_text_extractor(file_path)
        if not success or not raw_text:
            update_job_status(job_id, 'failed', f"Text extraction failed: {error}")
            return
        # Extract tables (optional, won't fail if it doesn't work)
        _, tables, _ = call_table_extractor(file_path)
        # STEP 2: NER to find entities
        logger.info("-" * 70)
        logger.info("STEP 2: NER - NAMED ENTITY RECOGNITION")
        logger.info("-" * 70)
        ner_success, entities, entity_map, ner_error = call_ner(raw_text, file_path)
        if not ner_success:
            logger.warning(f"⚠️ NER failed: {ner_error}, continuing without entities")
            entities = []
            entity_map = {}
        # STEP 3: Gemini intelligent mapping
        logger.info("-" * 70)
        logger.info("STEP 3: GEMINI INTELLIGENT MAPPING")
        logger.info("-" * 70)
        gemini_success, gemini_result, gemini_error = map_with_gemini(
            raw_text, entities, entity_map, tables
        )
        if gemini_success and gemini_result:
            # Use Gemini's mapping.  BUGFIX: `or`-guards instead of .get()
            # defaults — Gemini can return explicit JSON nulls, in which
            # case .get('key', default) yields None and slicing/float()
            # would raise TypeError.
            fields = {
                'cust_number': (gemini_result.get('customer_name') or 'UNKNOWN')[:20],
                'posting_date': gemini_result.get('date') or datetime.now().strftime('%Y-%m-%d'),
                'total_open_amount': float(gemini_result.get('total_amount') or 0.0),
                'business_code': 'U001',
                'cust_payment_terms': (gemini_result.get('payment_terms') or 'NAH4')[:10]
            }
            confidence_map = {
                'cust_number': 0.95,
                'posting_date': 0.95,
                'total_open_amount': 0.95,
                'business_code': 0.2,
                'cust_payment_terms': 0.8
            }
            method = 'hf_ner_gemini'
        else:
            # Fallback to regex mapping
            logger.warning(f"⚠️ Gemini mapping failed: {gemini_error}")
            logger.info("-" * 70)
            logger.info("FALLBACK: REGEX MAPPING")
            logger.info("-" * 70)
            fields, confidence_map = map_with_regex(raw_text, entities)
            method = 'hf_ner_regex'
        # Save results
        save_extraction(
            doc_id, raw_text, tables, entities,
            {'method': method, 'entity_count': len(entities)},
            None
        )
        save_invoice_fields(doc_id, fields, confidence_map)
        logger.info("=" * 70)
        logger.info(f"✅ EXTRACTION COMPLETE - Method: {method}")
        logger.info(f"📋 Fields: {fields}")
        logger.info("=" * 70)
        # NOTE: prediction is only triggered from the agent pipeline; the
        # legacy pipeline just records the extraction.
        update_job_status(job_id, 'completed')
        logger.info(f"🎉 Job {job_id} completed successfully")
    except Exception as e:
        logger.error(f"❌ Job {job_id} failed: {e}")
        import traceback
        traceback.print_exc()
        update_job_status(job_id, 'failed', str(e))
def process_document_agent(job_id: str, doc_id: str, file_path: Path, user_message: Optional[str] = None):
    """Autonomous-agent pipeline with an optional Gemini output filter.

    The agent (backend.app.agent.agent_orchestrator) decides which tools
    (text/table extraction, NER, Gemini mapping) to run.  If user_message
    is supplied, a Gemini wrapper filters the extracted fields down to what
    the user asked for before persisting.  Falls back to the legacy
    pipeline when the agent module cannot be imported; any other error
    marks the job failed.
    """
    try:
        # Clean up user_message: form posts may deliver placeholder strings
        # ('None', 'null', 'undefined') instead of an absent value.
        if user_message in [None, 'None', '', 'null', 'undefined']:
            user_message = None
        else:
            user_message = str(user_message).strip()
            if not user_message:
                user_message = None
        logger.info("=" * 70)
        logger.info(f"🔍 AGENT - Processing with message: '{user_message}'")
        logger.info(f"🔍 Type: {type(user_message)}")
        logger.info(f"🔍 Is None: {user_message is None}")
        logger.info("=" * 70)
        # Imported lazily so a missing agent package degrades to legacy mode
        # via the ImportError handler below.
        from backend.app.agent.agent_orchestrator import (
            InvoiceAgent, AgentState, create_agent
        )
        logger.info("=" * 70)
        logger.info(f"🤖 AUTONOMOUS AGENT MODE for {file_path.name}")
        logger.info("=" * 70)
        update_job_status(job_id, 'processing')
        # Create agent, wiring in this module's tool functions.
        agent = create_agent(
            call_text_extractor,
            call_table_extractor,
            call_ner,
            map_with_gemini
        )
        # Initialize state
        state = AgentState(doc_id=doc_id, file_path=file_path)
        # Let agent autonomously decide and execute
        result_state = agent.process(state)
        # ============================================
        # WRAPPER INTEGRATION
        # ============================================
        full_extraction = result_state.fields
        final_result = full_extraction
        wrapper_used = False
        # Check if user_message is actually provided.
        # NOTE(review): the key-listing logs below assume full_extraction is
        # a dict; if the agent produced no fields (None) they raise and the
        # outer except marks the job failed — confirm that is intended.
        if user_message is not None and len(user_message) > 0:
            logger.info("=" * 70)
            logger.info(f"💬 USER MESSAGE DETECTED: '{user_message}'")
            logger.info("🎯 Activating Gemini wrapper to filter output...")
            logger.info(f"📦 Full extraction fields: {list(full_extraction.keys())}")
            logger.info("=" * 70)
            try:
                from backend.app.wrappers.gemini_output_filter import GeminiOutputFilter
                wrapper = GeminiOutputFilter()
                final_result = wrapper.filter_output(user_message, full_extraction)
                wrapper_used = True
                logger.info("=" * 70)
                logger.info(f"✅ WRAPPER SUCCESS!")
                logger.info(f"📤 Original fields: {list(full_extraction.keys())}")
                logger.info(f"🎯 Filtered fields: {list(final_result.keys())}")
                logger.info(f"📋 Filtered result: {json.dumps(final_result, indent=2)}")
                logger.info("=" * 70)
            except Exception as wrapper_error:
                # Wrapper failures are non-fatal: keep the full extraction.
                logger.error("=" * 70)
                logger.error(f"❌ WRAPPER FAILED: {wrapper_error}")
                logger.error("=" * 70)
                import traceback
                logger.error(traceback.format_exc())
                logger.warning("📦 Falling back to full extraction")
                final_result = full_extraction
                wrapper_used = False
        else:
            logger.info("=" * 70)
            logger.info("ℹ️ No user message provided - returning full extraction")
            logger.info(f"📦 Full extraction fields: {list(full_extraction.keys())}")
            logger.info("=" * 70)
        # ============================================
        # Save results
        # ============================================
        if result_state.fields:
            # Determine method from the agent's decision history.
            if 'use_gemini' in result_state.history:
                method = 'autonomous_agent_gemini'
            elif 'use_regex' in result_state.history:
                method = 'autonomous_agent_regex'
            else:
                method = 'autonomous_agent'
            if wrapper_used:
                method += '_with_wrapper'
            save_extraction(
                doc_id,
                result_state.raw_text or '',
                result_state.tables or [],
                result_state.entities or [],
                {
                    'method': method,
                    'attempts': result_state.attempts,
                    'actions': result_state.history,
                    'confidence': agent._calculate_overall_confidence(result_state),
                    'errors': result_state.errors,
                    'user_message': user_message,
                    'wrapper_used': wrapper_used,
                    'full_extraction_keys': list(full_extraction.keys()) if full_extraction else [],
                    'filtered_keys': list(final_result.keys()) if wrapper_used else None
                },
                None
            )
            # Save filtered result (full extraction when no wrapper ran).
            save_invoice_fields(
                doc_id,
                final_result,
                result_state.confidence_map or {}
            )
            # Call prediction (best-effort; failure only logs).
            logger.info("🔮 Calling payment prediction...")
            try:
                pred_response = httpx.post(PREDICT_ENDPOINT, json=final_result, timeout=30)
                if pred_response.status_code == 200:
                    pred_result = pred_response.json()
                    logger.info(f"✅ Prediction: {pred_result.get('predicted_days_to_clear')} days")
            except Exception as e:
                logger.error(f"⚠️ Prediction failed: {e}")
            # Check status: the agent may have asked for a human in the loop.
            from backend.app.agent.agent_orchestrator import AgentDecision
            if AgentDecision.HUMAN_REVIEW.value in result_state.history:
                update_job_status(job_id, 'needs_review')
                logger.info("👤 Agent requesting human review")
            else:
                update_job_status(job_id, 'completed')
                logger.info(f"✅ Agent completed with confidence: {agent._calculate_overall_confidence(result_state):.2f}")
        else:
            update_job_status(job_id, 'failed', 'Agent could not extract fields')
            logger.error("❌ Agent failed to extract any fields")
    except ImportError as e:
        # Agent package not deployed: degrade gracefully to the old pipeline.
        logger.error(f"❌ Agent module not found: {e}")
        logger.info("⚠️ Falling back to legacy pipeline...")
        process_document_legacy(job_id, doc_id, file_path)
    except Exception as e:
        logger.error(f"❌ Agent failed: {e}")
        import traceback
        traceback.print_exc()
        update_job_status(job_id, 'failed', str(e))
def process_document(job_id: str, doc_id: str, file_path: Path, user_message: Optional[str] = None):
    """Dispatch a queued document to the agent or legacy pipeline.

    Normalizes user_message (treating 'None'/'null'/'undefined'/'' from
    form submissions as absent) before routing based on USE_AGENT_MODE.
    """
    # Normalize placeholder strings the frontend may send for "no message".
    if user_message in [None, 'None', '', 'null', 'undefined']:
        user_message = None
    else:
        user_message = str(user_message).strip() or None

    banner = "=" * 70
    logger.info(banner)
    logger.info(f"🔍 PROCESS_DOCUMENT - Cleaned user_message: '{user_message}'")
    logger.info(f"🔍 Type: {type(user_message)}")
    logger.info(f"🔍 Is None: {user_message is None}")
    logger.info(banner)

    if USE_AGENT_MODE:
        logger.info("🤖 Using AUTONOMOUS AGENT mode")
        process_document_agent(job_id, doc_id, file_path, user_message=user_message)
    else:
        logger.info("📋 Using LEGACY pipeline mode")
        process_document_legacy(job_id, doc_id, file_path)
# ============================================
# API Endpoints
# ============================================
class IngestResponse(BaseModel):
job_id: str
doc_id: str
filename: str
status: str
message: str
class JobStatusResponse(BaseModel):
    """Response body for job polling (GET /api/ingest/{job_id})."""
    job_id: str
    doc_id: str
    filename: str
    # One of: queued / processing / completed / needs_review / failed.
    status: str
    # Populated only when the job failed.
    error_text: Optional[str] = None
    created_at: str
    updated_at: str
    # OCR/extraction row (raw_text truncated); present once completed.
    extraction: Optional[Dict] = None
    # Structured invoice fields with parsed confidence_map; present once completed.
    invoice_fields: Optional[Dict] = None
class BatchIngestResponse(BaseModel):
    """Response body for a batch upload (POST /api/ingest/batch)."""
    batch_id: str
    # Number of files actually accepted (invalid types are skipped).
    total_files: int
    # One {job_id, doc_id, filename, status} entry per accepted file.
    jobs: List[Dict[str, str]]
    message: str
class BatchStatusResponse(BaseModel):
    """Response body for batch polling (GET /api/ingest/batch/{batch_id})."""
    batch_id: str
    total_files: int
    # Per-status counters; NOTE 'needs_review' jobs are not broken out here.
    completed: int
    processing: int
    failed: int
    queued: int
    # Full per-job detail rows.
    jobs: List[Dict[str, Any]]
@router.post("/ingest", response_model=IngestResponse)
async def ingest_document(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    message: str = Form(None)  # Form(None) so multipart clients may omit it
):
    """
    Upload a single document (PDF/PNG/JPEG) and queue background processing.

    The optional ``message`` form field requests intelligent field filtering
    (e.g. "extract only total and date"); sentinel strings sent by JS clients
    ('None', 'null', 'undefined', '') are treated as no message.

    Returns:
        IngestResponse with the new job/doc identifiers; extraction itself
        runs asynchronously via ``process_document``.

    Raises:
        HTTPException(400) for unsupported content types,
        HTTPException(500) for unexpected failures.
    """
    # Normalize the optional message to a real None.
    cleaned_message = None
    if message and message not in ['None', 'null', '', 'undefined']:
        cleaned_message = message.strip() or None
    logger.info("=" * 70)
    logger.info(f"📨 API ENDPOINT - Raw message: '{message}'")
    logger.info(f"✨ Cleaned message: '{cleaned_message}'")
    logger.info(f"🔍 Message type: {type(cleaned_message)}")
    logger.info(f"❓ Is None: {cleaned_message is None}")
    logger.info("=" * 70)
    try:
        allowed_types = ['application/pdf', 'image/png', 'image/jpeg']
        if file.content_type not in allowed_types:
            raise HTTPException(400, f"Invalid file type: {file.content_type}")
        job_id = f"job_{uuid.uuid4().hex[:12]}"
        doc_id = f"doc_{uuid.uuid4().hex[:12]}"
        file_ext = file.filename.split('.')[-1] if '.' in file.filename else 'pdf'
        stored_filename = f"{doc_id}.{file_ext}"
        file_path = STORAGE_PATH / stored_filename
        content = await file.read()
        with open(file_path, 'wb') as f:
            f.write(content)
        # Record job + document under the cross-process lock; close the
        # connection even if an INSERT raises (fix: previously leaked).
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            try:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO ingest_jobs (job_id, doc_id, filename, status)
                    VALUES (?, ?, ?, 'queued')
                """, (job_id, doc_id, file.filename))
                cursor.execute("""
                    INSERT INTO documents (doc_id, job_id, path, filename, content_type)
                    VALUES (?, ?, ?, ?, ?)
                """, (doc_id, job_id, str(file_path), file.filename, file.content_type))
                conn.commit()
            finally:
                conn.close()
        # Start processing with cleaned message.
        background_tasks.add_task(
            process_document,
            job_id,
            doc_id,
            file_path,
            user_message=cleaned_message
        )
        logger.info(f"🚀 Background task started with message: '{cleaned_message}'")
        mode = "autonomous agent"
        if cleaned_message:
            mode += " with intelligent filtering"
            logger.info(f"🎯 User wants: '{cleaned_message}'")
        return IngestResponse(
            job_id=job_id,
            doc_id=doc_id,
            filename=file.filename,
            status='queued',
            message=f'Document uploaded. Processing with {mode}.'
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Ingest endpoint error: {e}")
        import traceback
        logger.error(traceback.format_exc())
        raise HTTPException(500, str(e))
@router.get("/ingest/{job_id}", response_model=JobStatusResponse)
def get_ingest_status(job_id: str):
    """
    Get job status with agent decision history (if applicable).

    For completed / needs_review jobs the extraction row (raw_text truncated
    to 500 chars) and parsed invoice fields are attached to the response.

    Raises:
        HTTPException(404) when the job does not exist,
        HTTPException(500) for unexpected failures.
    """
    try:
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            # Ensure the connection is closed on every path (fix: previously
            # leaked if e.g. json.loads raised mid-query).
            try:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute("SELECT * FROM ingest_jobs WHERE job_id = ?", (job_id,))
                job = cursor.fetchone()
                if not job:
                    raise HTTPException(404, "Job not found")
                job_data = dict(job)
                doc_id = job_data['doc_id']
                if job_data['status'] in ['completed', 'needs_review']:
                    cursor.execute("SELECT * FROM extractions WHERE doc_id = ?", (doc_id,))
                    extraction = cursor.fetchone()
                    if extraction:
                        ext_dict = dict(extraction)
                        raw_text = ext_dict.get('raw_text')
                        # Only mark truncation when text was actually cut
                        # (fix: "..." was previously appended unconditionally).
                        if raw_text and len(raw_text) > 500:
                            ext_dict['raw_text'] = raw_text[:500] + "..."
                        job_data['extraction'] = ext_dict
                    cursor.execute("SELECT * FROM invoice_fields WHERE doc_id = ?", (doc_id,))
                    invoice = cursor.fetchone()
                    if invoice:
                        inv_dict = dict(invoice)
                        if inv_dict.get('confidence_map'):
                            inv_dict['confidence_map'] = json.loads(inv_dict['confidence_map'])
                        job_data['invoice_fields'] = inv_dict
            finally:
                conn.close()
        return JobStatusResponse(**job_data)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Job status error: {e}")
        raise HTTPException(500, str(e))
@router.post("/ingest/batch", response_model=BatchIngestResponse)
async def ingest_batch_documents(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...),
    message: str = Form(None)
):
    """
    Upload multiple documents for batch processing.

    Files with unsupported content types are skipped (not failed); each
    accepted file gets its own job and is queued through the same
    ``process_document`` pipeline as single uploads. The optional ``message``
    applies the same extraction rule to every file in the batch.

    Examples:
    1. Batch upload without filtering:
       curl -F "files=@invoice1.jpg" -F "files=@invoice2.pdf" -F "files=@invoice3.png" \
       http://localhost:7860/api/ingest/batch
    2. Batch upload with same extraction rule for all:
       curl -F "files=@invoice1.jpg" -F "files=@invoice2.jpg" \
       -F "message=extract only total and date" \
       http://localhost:7860/api/ingest/batch
    3. Maximum 50 files per batch

    Raises:
        HTTPException(400) for empty/oversized batches or no valid files,
        HTTPException(500) for unexpected failures.
    """
    # Validate batch size
    if len(files) > 50:
        raise HTTPException(400, "Maximum 50 files per batch")
    if len(files) == 0:
        raise HTTPException(400, "No files provided")
    # Normalize the optional message (JS sentinels -> None).
    cleaned_message = None
    if message and message not in ['None', 'null', '', 'undefined']:
        cleaned_message = message.strip() or None
    batch_id = f"batch_{uuid.uuid4().hex[:12]}"
    jobs = []
    logger.info("=" * 70)
    logger.info(f"📦 BATCH UPLOAD - {len(files)} files")
    logger.info(f"📦 Batch ID: {batch_id}")
    logger.info(f"📦 Message: '{cleaned_message}'")
    logger.info("=" * 70)
    try:
        allowed_types = ['application/pdf', 'image/png', 'image/jpeg']
        for idx, file in enumerate(files):
            # Skip (don't fail) files with an unsupported type.
            if file.content_type not in allowed_types:
                logger.warning(f"⚠️ Skipping {file.filename} - invalid type: {file.content_type}")
                continue
            # Create job for this file
            job_id = f"job_{uuid.uuid4().hex[:12]}"
            doc_id = f"doc_{uuid.uuid4().hex[:12]}"
            file_ext = file.filename.split('.')[-1] if '.' in file.filename else 'pdf'
            stored_filename = f"{doc_id}.{file_ext}"
            file_path = STORAGE_PATH / stored_filename
            # Save file to disk
            content = await file.read()
            with open(file_path, 'wb') as f:
                f.write(content)
            # Save to database; close the connection even if an INSERT
            # raises (fix: previously leaked on failure).
            with FileLock(str(LOCK_PATH), timeout=10):
                conn = sqlite3.connect(str(DB_PATH))
                try:
                    cursor = conn.cursor()
                    cursor.execute("""
                        INSERT INTO ingest_jobs (job_id, doc_id, filename, status)
                        VALUES (?, ?, ?, 'queued')
                    """, (job_id, doc_id, file.filename))
                    cursor.execute("""
                        INSERT INTO documents (doc_id, job_id, path, filename, content_type)
                        VALUES (?, ?, ?, ?, ?)
                    """, (doc_id, job_id, str(file_path), file.filename, file.content_type))
                    conn.commit()
                finally:
                    conn.close()
            # Queue processing
            background_tasks.add_task(
                process_document,
                job_id,
                doc_id,
                file_path,
                user_message=cleaned_message
            )
            jobs.append({
                'job_id': job_id,
                'doc_id': doc_id,
                'filename': file.filename,
                'status': 'queued'
            })
            logger.info(f"✅ [{idx+1}/{len(files)}] Queued: {file.filename}")
        if not jobs:
            raise HTTPException(400, "No valid files to process")
        # Save batch metadata (tables are created lazily on first batch).
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            try:
                cursor = conn.cursor()
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS batch_jobs (
                        batch_id TEXT PRIMARY KEY,
                        total_files INTEGER,
                        message TEXT,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)
                cursor.execute("""
                    INSERT INTO batch_jobs (batch_id, total_files, message)
                    VALUES (?, ?, ?)
                """, (batch_id, len(jobs), cleaned_message))
                # Link jobs to batch
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS batch_job_mapping (
                        batch_id TEXT,
                        job_id TEXT,
                        FOREIGN KEY (job_id) REFERENCES ingest_jobs(job_id)
                    )
                """)
                for job in jobs:
                    cursor.execute("""
                        INSERT INTO batch_job_mapping (batch_id, job_id)
                        VALUES (?, ?)
                    """, (batch_id, job['job_id']))
                conn.commit()
            finally:
                conn.close()
        mode = "autonomous agent"
        if cleaned_message:
            mode += " with intelligent filtering"
        logger.info(f"🚀 Batch {batch_id} processing started with {len(jobs)} files")
        return BatchIngestResponse(
            batch_id=batch_id,
            total_files=len(jobs),
            jobs=jobs,
            message=f'Batch of {len(jobs)} documents uploaded. Processing with {mode}.'
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Batch ingest error: {e}")
        import traceback
        logger.error(traceback.format_exc())
        raise HTTPException(500, str(e))
@router.get("/ingest/batch/{batch_id}", response_model=BatchStatusResponse)
def get_batch_status(batch_id: str):
    """
    Get status of all jobs in a batch.

    Example:
        curl http://localhost:7860/api/ingest/batch/batch_abc123

    Raises:
        HTTPException(404) when the batch does not exist,
        HTTPException(500) for unexpected failures.
    """
    try:
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            # Close the connection on every path (fix: previously leaked
            # if a query raised before the explicit close).
            try:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                # Get batch info
                cursor.execute("SELECT * FROM batch_jobs WHERE batch_id = ?", (batch_id,))
                batch = cursor.fetchone()
                if not batch:
                    raise HTTPException(404, "Batch not found")
                # Get all jobs in batch
                cursor.execute("""
                    SELECT j.* FROM ingest_jobs j
                    JOIN batch_job_mapping bm ON j.job_id = bm.job_id
                    WHERE bm.batch_id = ?
                """, (batch_id,))
                jobs = cursor.fetchall()
            finally:
                conn.close()
        # Tally per-status counts. NOTE(review): 'needs_review' is counted
        # here but BatchStatusResponse has no field for it, so those jobs
        # only appear in the detailed jobs list — confirm this is intended.
        status_counts = {
            'completed': 0,
            'processing': 0,
            'failed': 0,
            'queued': 0,
            'needs_review': 0
        }
        jobs_list = []
        for job in jobs:
            job_dict = dict(job)
            status = job_dict['status']
            status_counts[status] = status_counts.get(status, 0) + 1
            jobs_list.append({
                'job_id': job_dict['job_id'],
                'doc_id': job_dict['doc_id'],
                'filename': job_dict['filename'],
                'status': status,
                'error_text': job_dict.get('error_text'),
                'created_at': job_dict['created_at'],
                'updated_at': job_dict['updated_at']
            })
        return BatchStatusResponse(
            batch_id=batch_id,
            total_files=len(jobs),
            completed=status_counts['completed'],
            processing=status_counts['processing'],
            failed=status_counts['failed'],
            queued=status_counts['queued'],
            jobs=jobs_list
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Batch status error: {e}")
        raise HTTPException(500, str(e))
@router.get("/ingest/batch/{batch_id}/download")
def download_batch_results(batch_id: str):
    """
    Download all extracted data from a batch as CSV.

    Only jobs with status 'completed' are included; invoice fields are
    LEFT JOINed, so missing values are rendered as 'N/A' (or 0.0 for the
    amount column).

    Example:
        curl http://localhost:7860/api/ingest/batch/batch_abc123/download -o results.csv

    Raises:
        HTTPException(404) when the batch has no completed jobs,
        HTTPException(500) for unexpected failures.
    """
    try:
        # csv is already imported at module level; only these two are local.
        from io import StringIO
        from fastapi.responses import StreamingResponse
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            # Close the connection on every path (fix: previously leaked
            # if the query raised).
            try:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                # Get all completed jobs in batch with their invoice fields.
                cursor.execute("""
                    SELECT j.*, f.* FROM ingest_jobs j
                    JOIN batch_job_mapping bm ON j.job_id = bm.job_id
                    LEFT JOIN invoice_fields f ON j.doc_id = f.doc_id
                    WHERE bm.batch_id = ? AND j.status = 'completed'
                """, (batch_id,))
                results = cursor.fetchall()
            finally:
                conn.close()
        if not results:
            raise HTTPException(404, "No completed jobs found in batch")
        # Build the CSV in memory.
        output = StringIO()
        writer = csv.writer(output)
        # Header
        writer.writerow([
            'filename', 'doc_id', 'customer', 'date', 'amount',
            'payment_terms', 'business_code', 'status'
        ])
        # Data rows
        for row in results:
            writer.writerow([
                row['filename'],
                row['doc_id'],
                row['cust_number'] or 'N/A',
                row['posting_date'] or 'N/A',
                row['total_open_amount'] or 0.0,
                row['cust_payment_terms'] or 'N/A',
                row['business_code'] or 'N/A',
                row['status']
            ])
        output.seek(0)
        return StreamingResponse(
            iter([output.getvalue()]),
            media_type="text/csv",
            headers={
                "Content-Disposition": f"attachment; filename=batch_{batch_id}_results.csv"
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        # Log before converting to a 500 (fix: previously swallowed silently,
        # unlike every other endpoint in this module).
        logger.error(f"❌ Batch download error: {e}")
        raise HTTPException(500, str(e))