Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI Backend for Invoice OCR System | |
| """ | |
| # Setup Hugging Face credentials (must be first) | |
| try: | |
| import setup_hf_credentials | |
| except ImportError: | |
| pass # Not on Hugging Face, credentials already set | |
| import os | |
| import json | |
| import shutil | |
| from pathlib import Path | |
| from typing import List, Optional | |
| from datetime import datetime | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse, Response | |
| from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime, JSON, Text, func | |
| from sqlalchemy.ext.declarative import declarative_base | |
| from sqlalchemy.orm import sessionmaker, Session | |
| from ocr_invoice import InvoiceOCR | |
| from cost_tracker import CostTracker | |
| # Removed auth for Vercel deployment | |
# ---------------------------------------------------------------------------
# Application object
# ---------------------------------------------------------------------------
app = FastAPI(title="Invoice OCR API", version="1.0.0")

# CORS is fully open: any origin, method and header is accepted and
# credentials are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ---------------------------------------------------------------------------
# Storage locations
# ---------------------------------------------------------------------------
# The existence of /app is used as the marker for the Hugging Face
# environment, where both the SQLite file and the uploads live under /tmp.
# Anywhere else, paths relative to the working directory are used.
if os.path.exists("/app"):
    DATABASE_URL, UPLOAD_DIR = "sqlite:////tmp/invoices.db", Path("/tmp/uploads")
else:
    DATABASE_URL, UPLOAD_DIR = "sqlite:///./invoices.db", Path("./uploads")

# Make sure the upload directory is there before any request arrives.
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

# ---------------------------------------------------------------------------
# SQLAlchemy plumbing
# ---------------------------------------------------------------------------
# check_same_thread=False is handed to the SQLite driver so a connection may
# be used from a thread other than the one that created it.
engine = create_engine(
    DATABASE_URL,
    connect_args={"check_same_thread": False},
)
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
Base = declarative_base()
# Database Models
class Invoice(Base):
    """One processed invoice, stored in the ``invoices`` table.

    Scalar fields hold the values pulled out of the OCR result; the nested
    structures (line items, supplier/customer objects, payment info and the
    full OCR payload) are kept as JSON-encoded text.
    """

    __tablename__ = "invoices"

    id = Column(Integer, primary_key=True, index=True)
    filename = Column(String, index=True)

    # --- Supplier ---
    supplier_name = Column(String)
    supplier_address = Column(Text)

    # --- Customer ---
    customer_name = Column(String)
    customer_address = Column(Text)

    # --- Invoice header (dates are kept as plain strings) ---
    invoice_number = Column(String, index=True)
    invoice_date = Column(String)
    due_date = Column(String)
    po_number = Column(String)
    payment_terms = Column(String)

    # --- Financial summary ---
    subtotal = Column(Float)
    tax_amount = Column(Float)
    total_amount = Column(Float)
    currency = Column(String)

    # --- JSON-encoded text payloads ---
    line_items = Column(Text)
    supplier_data = Column(Text)
    customer_data = Column(Text)
    payment_info = Column(Text)
    additional_notes = Column(Text)
    raw_data = Column(Text)

    # --- Processing metadata ---
    processing_cost = Column(Float, default=0.0)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
    )


# Create every table registered on Base (no-op for tables that already exist).
Base.metadata.create_all(bind=engine)
# Dependency
def get_db():
    """Yield a fresh ``SessionLocal`` session and always close it afterwards.

    Written as a generator so it can be used with ``Depends``: the session is
    handed to the consumer at the ``yield`` and closed in the ``finally``
    clause whether or not the consumer raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# Shared, module-level service objects.
# The OCR processor is configured entirely from environment variables; any
# variable that is not set is passed through as None.
ocr_processor = InvoiceOCR(
    project_id=os.environ.get("PROJECT_ID"),
    location=os.environ.get("LOCATION"),
    processor_id=os.environ.get("PROCESSOR_ID"),
    gemini_api_key=os.environ.get("GEMINI_API_KEY"),
)

# Computes the per-invoice processing cost reported by the upload handler.
cost_tracker = CostTracker()
| # API Routes | |
# NOTE(review): no @app route decorator is visible on this handler in this
# view of the file; confirm it is registered with the app.
async def root():
    """Return ``static/index.html`` if present, else a small JSON pointer.

    The static directory is looked up next to this source file. When the
    index page is missing, a dict naming the API and its docs path is
    returned instead.
    """
    index_page = Path(__file__).parent / "static" / "index.html"
    if not index_page.exists():
        return {"message": "Invoice OCR API", "docs": "/docs"}
    return FileResponse(index_page)
# NOTE(review): no @app route decorator is visible on this handler in this
# view of the file; confirm it is registered with the app.
async def favicon():
    """Return ``static/favicon.ico`` if present, else an empty 204 response.

    Answering with 204 (No Content) instead of 404 keeps missing-favicon
    warnings out of the logs.
    """
    icon = Path(__file__).parent / "static" / "favicon.ico"
    if not icon.exists():
        return Response(status_code=204)
    return FileResponse(icon)
def _numeric_or_zero(value):
    """Return *value* as a number; missing or non-numeric values count as 0."""
    if isinstance(value, (int, float)):
        return value
    if isinstance(value, str):
        try:
            return float(value)
        except ValueError:
            return 0
    return 0


def _discard_upload(path):
    """Best-effort removal of a saved upload; filesystem errors are ignored."""
    if path is None:
        return
    try:
        path.unlink()
    except OSError:
        # Includes FileNotFoundError; cleanup must never mask the real error.
        pass


# NOTE(review): no @app route decorator is visible on this handler in this
# view of the file; confirm it is registered with the app.
async def upload_invoice(file: UploadFile = File(...), db: Session = Depends(get_db)):
    """Save an uploaded invoice, run OCR on it and persist the result.

    Steps:
      1. Store the upload under ``UPLOAD_DIR`` using only the final path
         component of the client-supplied filename.
      2. Run ``ocr_processor.process_invoice`` on the saved file.
      3. Compute the processing cost with ``cost_tracker``.
      4. Fill in missing financial-summary fields from the line items.
      5. Insert an ``Invoice`` row and return it together with the costs.

    Args:
        file: The uploaded invoice document.
        db: Database session supplied by ``get_db``.

    Returns:
        A dict with ``success``, ``invoice`` (nested structures are
        JSON-encoded strings) and ``costs``.

    Raises:
        HTTPException: 400 when the filename is empty or unusable; 500 when
            the OCR result contains an ``error`` key or any other step
            fails. On every failure the saved upload is removed again.
    """
    file_path = None
    try:
        # The filename comes from the client: keep only its last component so
        # it cannot escape UPLOAD_DIR ("../x", "/etc/x", "..\\x").
        safe_name = Path((file.filename or "").replace("\\", "/")).name
        if safe_name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid filename")

        # Save uploaded file
        file_path = UPLOAD_DIR / safe_name
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # Process with OCR
        print(f"Processing: {safe_name}")
        invoice_data = ocr_processor.process_invoice(str(file_path), save_json=False)

        if "error" in invoice_data:
            error_msg = invoice_data.get("error", "Unknown error")
            raw_response = invoice_data.get("raw_response", "")
            print(f"⚠ Invoice processing error: {error_msg}")
            if raw_response:
                print(f"Raw response: {raw_response[:500]}")
            raise HTTPException(status_code=500, detail=f"OCR Error: {error_msg}")

        # Extract metadata for cost calculation (removed from the payload so
        # it is neither stored nor counted as output text).
        metadata = invoice_data.pop("_processing_metadata", {})
        raw_text = metadata.get("raw_text", "")
        includes_image = metadata.get("includes_image", True)

        # Calculate processing cost
        costs = cost_tracker.calculate_invoice_cost(
            input_text=raw_text,
            output_text=json.dumps(invoice_data),
            includes_image=includes_image,
        )

        # Extract data
        supplier = invoice_data.get("supplier", {})
        customer = invoice_data.get("customer", {})
        inv_details = invoice_data.get("invoice_details", {})
        financial = invoice_data.get("financial_summary", {})
        line_items = invoice_data.get("line_items", [])

        # Calculate totals from line items if not provided
        if line_items:
            # Tolerate malformed items: non-dict entries are skipped and
            # missing / null / non-numeric prices count as 0.
            calculated_subtotal = sum(
                _numeric_or_zero(item.get("total_price"))
                for item in line_items
                if isinstance(item, dict)
            )

            # If financial summary is missing or incomplete, calculate it
            if not financial or not isinstance(financial, dict):
                financial = {}

            # Use calculated subtotal if not provided
            if not financial.get("subtotal"):
                financial["subtotal"] = round(calculated_subtotal, 2)

            # Assume no tax when none was extracted
            if not financial.get("tax_amount"):
                financial["tax_amount"] = 0.0

            # Calculate total_amount if not provided; the operands may be
            # numeric strings when they came straight from the OCR result.
            if not financial.get("total_amount"):
                financial["total_amount"] = round(
                    _numeric_or_zero(financial.get("subtotal", 0))
                    + _numeric_or_zero(financial.get("tax_amount", 0)),
                    2,
                )

            # Set currency if not provided
            if not financial.get("currency"):
                financial["currency"] = "EUR"

            print("✓ Financial summary calculated:")
            print(f" Subtotal: {financial.get('subtotal')} (from {len(line_items)} line items)")
            print(f" Tax: {financial.get('tax_amount')}")
            print(f" Total: {financial.get('total_amount')}")

        # Handle both old format (nested objects) and new format (simple
        # strings): a bare string becomes an object with only "name" set.
        if isinstance(supplier, str):
            supplier = {"name": supplier, "address": "", "phone": "", "email": "", "tax_id": "", "registration_number": ""}
        if isinstance(customer, str):
            customer = {"name": customer, "address": "", "phone": "", "email": ""}

        # Convert to JSON strings for storage
        line_items_json = json.dumps(line_items)
        supplier_json = json.dumps(supplier)
        customer_json = json.dumps(customer)
        payment_json = json.dumps(invoice_data.get("payment_info", {}))
        raw_json = json.dumps(invoice_data)

        # Save to database
        db_invoice = Invoice(
            filename=safe_name,
            supplier_name=supplier.get("name", "") if isinstance(supplier, dict) else str(supplier),
            supplier_address=supplier.get("address", "") if isinstance(supplier, dict) else "",
            customer_name=customer.get("name", "") if isinstance(customer, dict) else str(customer),
            customer_address=customer.get("address", "") if isinstance(customer, dict) else "",
            invoice_number=inv_details.get("invoice_number", "") if isinstance(inv_details, dict) else str(invoice_data.get("invoice_number", "")),
            invoice_date=inv_details.get("invoice_date") if isinstance(inv_details, dict) else invoice_data.get("invoice_date") or invoice_data.get("date"),
            due_date=inv_details.get("due_date") if isinstance(inv_details, dict) else None,
            po_number=inv_details.get("po_number") if isinstance(inv_details, dict) else None,
            payment_terms=inv_details.get("payment_terms") if isinstance(inv_details, dict) else None,
            subtotal=financial.get("subtotal") if isinstance(financial, dict) else None,
            tax_amount=financial.get("tax_amount") if isinstance(financial, dict) else None,
            total_amount=financial.get("total_amount") if isinstance(financial, dict) else None,
            currency=financial.get("currency", "") if isinstance(financial, dict) else "EUR",
            line_items=line_items_json,
            supplier_data=supplier_json,
            customer_data=customer_json,
            payment_info=payment_json,
            additional_notes=str(invoice_data.get("additional_notes", "")),
            raw_data=raw_json,
            processing_cost=costs["total"],
        )
        db.add(db_invoice)
        db.commit()
        db.refresh(db_invoice)

        # Return response with proper cost structure
        return {
            "success": True,
            "invoice": {
                "id": db_invoice.id,
                "filename": db_invoice.filename,
                "supplier_data": supplier_json,
                "customer_data": customer_json,
                "invoice_details": json.dumps(inv_details),
                "line_items": line_items_json,
                "financial_summary": json.dumps(financial),
                "payment_info": payment_json,
                "additional_notes": db_invoice.additional_notes,
                "processing_cost": db_invoice.processing_cost,
            },
            "costs": {
                "total_cost": costs["total"],
                "document_ai_cost": costs["document_ai"],
                "gemini_input_cost": costs["gemini_input"],
                "gemini_output_cost": costs["gemini_output"],
                "gemini_input_tokens": costs["tokens"]["input"],
                "gemini_output_tokens": costs["tokens"]["output"],
            },
        }
    except HTTPException:
        # Keep the status code chosen above, but do not leave the upload of a
        # failed request behind.
        _discard_upload(file_path)
        raise
    except Exception as e:
        import traceback

        print("⚠ Error processing invoice:")
        print(traceback.format_exc())
        _discard_upload(file_path)
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
# NOTE(review): no @app route decorator is visible on this handler in this
# view of the file; confirm it is registered with the app.
async def get_invoices(limit: int = 10, db: Session = Depends(get_db)):
    """Return the most recently created invoices as a list of summaries.

    Args:
        limit: Maximum number of invoices to return (newest first).
        db: Database session supplied by ``get_db``.

    Returns:
        A list of dicts, one per invoice. ``invoice_details`` and
        ``financial_summary`` are JSON-encoded strings rebuilt from the
        scalar columns; ``supplier_data`` is the stored JSON text as is.
        ``items_count`` is 0 when the stored line items cannot be decoded
        or are not a sized value.
    """
    invoices = db.query(Invoice).order_by(Invoice.created_at.desc()).limit(limit).all()

    result = []
    for inv in invoices:
        # ValueError covers malformed JSON; TypeError covers a decoded value
        # without a length (e.g. a stored "null").
        try:
            line_items = json.loads(inv.line_items) if inv.line_items else []
            items_count = len(line_items)
        except (TypeError, ValueError):
            items_count = 0

        # Build invoice_details JSON
        invoice_details_json = json.dumps({
            "invoice_number": inv.invoice_number,
            "invoice_date": inv.invoice_date,
            "due_date": inv.due_date,
            "po_number": inv.po_number,
            "payment_terms": inv.payment_terms,
        })

        # Build financial_summary JSON
        financial_summary_json = json.dumps({
            "subtotal": inv.subtotal,
            "tax_amount": inv.tax_amount,
            "total_amount": inv.total_amount,
            "currency": inv.currency,
        })

        result.append({
            "id": inv.id,
            "filename": inv.filename,
            "supplier_name": inv.supplier_name,
            "customer_name": inv.customer_name,
            "invoice_number": inv.invoice_number,
            "invoice_date": inv.invoice_date,
            "due_date": inv.due_date,
            "total_amount": inv.total_amount,
            "currency": inv.currency,
            "items_count": items_count,
            "processing_cost": inv.processing_cost,
            "created_at": inv.created_at.isoformat() if inv.created_at else None,
            # JSON fields needed by the frontend
            "invoice_details": invoice_details_json,
            "financial_summary": financial_summary_json,
            "supplier_data": inv.supplier_data,
        })
    return result
# NOTE(review): no @app route decorator is visible on this handler in this
# view of the file; confirm it is registered with the app.
async def debug_invoice(invoice_id: int, db: Session = Depends(get_db)):
    """Return the decoded line items stored for one invoice.

    Args:
        invoice_id: Primary key of the invoice to inspect.
        db: Database session supplied by ``get_db``.

    Returns:
        A dict with the invoice ``filename``, its decoded ``line_items`` and
        ``items_count``.

    Raises:
        HTTPException: 404 when no invoice has the given id.
        ValueError: When the stored line items are not valid JSON; this is
            deliberately not swallowed so corrupt data stays visible here.
    """
    invoice = db.query(Invoice).filter(Invoice.id == invoice_id).first()
    if not invoice:
        raise HTTPException(status_code=404, detail="Invoice not found")

    line_items = json.loads(invoice.line_items) if invoice.line_items else []
    # A stored JSON "null" decodes to None; report it as an empty list.
    if line_items is None:
        line_items = []

    # Other unsized JSON values (numbers, booleans) have no length.
    try:
        items_count = len(line_items)
    except TypeError:
        items_count = 0

    return {
        "filename": invoice.filename,
        "line_items": line_items,
        "items_count": items_count,
    }
# NOTE(review): no @app route decorator is visible on this handler in this
# view of the file; confirm it is registered with the app.
async def get_invoice(invoice_id: int, db: Session = Depends(get_db)):
    """Return the full stored record of one invoice plus a cost section.

    ``invoice_details`` and ``financial_summary`` are JSON-encoded strings
    rebuilt from the scalar columns; the other nested fields are returned as
    the JSON text kept in the database. Only the total processing cost is
    stored, so the remaining entries of ``costs`` are fixed values.

    Raises:
        HTTPException: 404 when no invoice has the given id.
    """
    record = db.query(Invoice).filter(Invoice.id == invoice_id).first()
    if not record:
        raise HTTPException(status_code=404, detail="Invoice not found")

    # Header fields, re-assembled into one object.
    details = {
        "invoice_number": record.invoice_number,
        "invoice_date": record.invoice_date,
        "due_date": record.due_date,
        "po_number": record.po_number,
        "payment_terms": record.payment_terms,
    }

    # Monetary fields, re-assembled into one object.
    summary = {
        "subtotal": record.subtotal,
        "tax_amount": record.tax_amount,
        "total_amount": record.total_amount,
        "currency": record.currency,
    }

    invoice_payload = {
        "id": record.id,
        "filename": record.filename,
        "supplier_data": record.supplier_data,
        "customer_data": record.customer_data,
        "invoice_details": json.dumps(details),
        "line_items": record.line_items,
        "financial_summary": json.dumps(summary),
        "payment_info": record.payment_info,
        "currency": record.currency,
    }

    cost_payload = {
        "document_ai_cost": 0.0015,
        "gemini_input_tokens": 0,
        "gemini_input_cost": 0.0,
        "gemini_output_tokens": 0,
        "gemini_output_cost": 0.0,
        "total_cost": record.processing_cost,
    }

    return {"invoice": invoice_payload, "costs": cost_payload}
# NOTE(review): no @app route decorator is visible on this handler in this
# view of the file; confirm it is registered with the app.
async def delete_invoice(invoice_id: int, db: Session = Depends(get_db)):
    """Remove one invoice row from the database.

    Only the database row is deleted; nothing under the upload directory is
    touched.

    Raises:
        HTTPException: 404 when no invoice has the given id.
    """
    target = db.query(Invoice).filter(Invoice.id == invoice_id).first()
    if target:
        db.delete(target)
        db.commit()
        return {"success": True, "message": "Invoice deleted"}
    raise HTTPException(status_code=404, detail="Invoice not found")
# NOTE(review): no @app route decorator is visible on this handler in this
# view of the file; confirm it is registered with the app.
async def get_stats(db: Session = Depends(get_db)):
    """Return aggregate figures over all stored invoices.

    Returns:
        A dict with the invoice count, the summed invoice amounts (2
        decimals), the summed processing cost (6 decimals) and the average
        processing cost per invoice (6 decimals, 0 when there are none).
    """
    invoice_count = db.query(Invoice).count()

    # SUM() yields NULL on an empty table; fall back to 0 in that case.
    amount_sum = (
        db.query(Invoice).with_entities(func.sum(Invoice.total_amount)).scalar() or 0
    )
    cost_sum = (
        db.query(Invoice).with_entities(func.sum(Invoice.processing_cost)).scalar() or 0
    )

    if invoice_count > 0:
        average_cost = round(cost_sum / invoice_count, 6)
    else:
        average_cost = 0

    return {
        "total_invoices": invoice_count,
        "total_invoice_amount": round(amount_sum, 2),
        "total_processing_cost": round(cost_sum, 6),
        "average_cost_per_invoice": average_cost,
    }
# Serve the "static" directory that sits next to this file under /static,
# but only when that directory is actually present.
static_dir = Path(__file__).parent / "static"
if static_dir.exists():
    static_app = StaticFiles(directory=str(static_dir))
    app.mount("/static", static_app, name="static")

# Direct execution: run the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)