Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI Backend for OCR Receipt Processing | |
| With SQLite database and Hugging Face Spaces deployment ready | |
| """ | |
| from fastapi import FastAPI, File, UploadFile, HTTPException, Depends | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime, JSON, Text | |
| from sqlalchemy.ext.declarative import declarative_base | |
| from sqlalchemy.orm import sessionmaker, Session | |
| from datetime import datetime | |
| from typing import List, Optional | |
| import os | |
| import shutil | |
| import json | |
| from pathlib import Path | |
| from dotenv import load_dotenv | |
# Read a local .env file (if any) into the process environment.
load_dotenv()

# Project-local modules. They are imported after load_dotenv(), matching the
# original statement order -- presumably so they can see the loaded
# environment at import time (confirm before moving them to the top).
from ocr_with_gemini import EnhancedReceiptOCR
from cost_tracker import CostTracker

# Single shared tracker used to price every processed receipt.
cost_tracker = CostTracker()

# Database location: the presence of /app is treated as "running on
# Hugging Face", where /tmp is used for write access; otherwise the database
# file lives in the current working directory.
SQLALCHEMY_DATABASE_URL = (
    "sqlite:////tmp/receipts.db"
    if os.path.exists("/app")
    else "sqlite:///./receipts.db"
)

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

# Upload location follows the same rule as the database (/tmp on Hugging
# Face). The static directory is the same relative path in both environments
# and is expected to already exist in the repository.
UPLOAD_DIR = Path("/tmp/uploads") if os.path.exists("/app") else Path("uploads")
STATIC_DIR = Path("static")

# Directory creation is best-effort: a permission problem is reported with a
# warning instead of aborting start-up.
try:
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
except PermissionError:
    print(f"Warning: Could not create upload directory at {UPLOAD_DIR}")

try:
    if not STATIC_DIR.exists():
        STATIC_DIR.mkdir(parents=True, exist_ok=True)
except PermissionError:
    print(f"Warning: Static directory doesn't exist at {STATIC_DIR}")
| # Database Models | |
class Receipt(Base):
    """ORM model for one processed receipt, stored in the ``receipts`` table."""

    __tablename__ = "receipts"

    # Identity and source file
    id = Column(Integer, primary_key=True, index=True)
    filename = Column(String, index=True)

    # Merchant details
    merchant_name = Column(String, index=True)
    merchant_address = Column(Text)

    # Date and time are kept as plain strings, not as DateTime values
    date = Column(String)
    time = Column(String)

    # Monetary amounts
    total = Column(Float)
    subtotal = Column(Float)
    tax = Column(Float)
    discount = Column(Float)
    currency = Column(String)

    # Payment details
    payment_method = Column(String)
    receipt_number = Column(String)

    # Structured payloads stored as JSON
    items = Column(JSON)  # Line items
    additional_info = Column(JSON)  # Additional info
    raw_data = Column(JSON)  # Complete raw data

    # Cost to process this receipt
    processing_cost = Column(Float, default=0.0)

    # Row timestamps: both are filled on insert, updated_at again on update
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Create tables
Base.metadata.create_all(bind=engine)

# Initialize FastAPI
app = FastAPI(
    title="Receipt OCR API",
    description="Advanced OCR API for receipts and bills using Document AI + Gemini AI",
    version="1.0.0"
)

# CORS middleware
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; confirm this is intended before exposing the API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files from STATIC_DIR (the same path the index page is served
# from). The directory may be missing when its creation failed with a
# PermissionError earlier in this module; StaticFiles would then raise while
# the module is being imported, so the mount is skipped with a warning.
if STATIC_DIR.is_dir():
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
else:
    print(f"Warning: Static directory doesn't exist at {STATIC_DIR}; /static not mounted")

# Initialize OCR processor - Simple initialization
# A failure here is deliberately non-fatal: ocr_processor is left as None and
# the upload handler rejects requests while it is unavailable.
try:
    ocr_processor = EnhancedReceiptOCR(
        project_id=os.getenv("PROJECT_ID"),
        location=os.getenv("LOCATION"),
        processor_id=os.getenv("PROCESSOR_ID"),
        gemini_api_key=os.getenv("GEMINI_API_KEY")
    )
    print("✓ OCR processor initialized successfully")
except Exception as e:
    print(f"Warning: OCR processor initialization failed: {e}")
    ocr_processor = None
| # Dependency | |
def get_db():
    """Yield a database session and close it when the consumer is done.

    Written as a generator so it can be used as a ``Depends`` dependency:
    the session is handed out at the ``yield`` and closed in the ``finally``
    clause, whether or not the consumer raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
| # Routes | |
# NOTE(review): the route decorator was missing, so this handler was never
# reachable. The path "/" is inferred from the function name and docstring --
# confirm against the frontend.
@app.get("/")
async def root():
    """Serve the main HTML interface, falling back to a JSON service summary.

    Returns:
        A ``FileResponse`` for ``index.html`` in the static directory when
        that file exists; otherwise a dict with the service status, version
        and a short description of the available endpoints.
    """
    html_file = STATIC_DIR / "index.html"
    if html_file.exists():
        return FileResponse(html_file)

    # No frontend available: describe the API instead.
    return {
        "status": "healthy",
        "service": "Receipt OCR API",
        "version": "1.0.0",
        "endpoints": {
            "upload": "/upload - POST - Upload receipt image",
            "receipts": "/receipts - GET - List all receipts",
            "receipt": "/receipts/{id} - GET - Get specific receipt",
            "delete": "/receipts/{id} - DELETE - Delete receipt",
            "stats": "/stats - GET - Get statistics"
        }
    }
# Route restored from the endpoint listing returned by the index handler
# ("/upload - POST"); without the decorator the handler is never registered.
@app.post("/upload")
async def upload_receipt(
    file: UploadFile = File(...),
    db: Session = Depends(get_db)
):
    """
    Upload and process a receipt image

    Args:
        file: Receipt image file (JPG, PNG, PDF, etc.)
        db: Database session supplied by the ``get_db`` dependency

    Returns:
        Processed receipt data with database ID and a processing-cost summary

    Raises:
        HTTPException: 500 when the OCR processor is not initialized, 400 when
            the file name has no allowed extension, 500 when OCR reports an
            error or any later processing step fails.
    """
    if not ocr_processor:
        raise HTTPException(status_code=500, detail="OCR processor not initialized")

    # The file name is client-controlled. Keep only its last path component
    # (treating "\" as a separator too) so a name such as "../../x.png" cannot
    # write outside UPLOAD_DIR. A missing name becomes "" and is rejected by
    # the extension check below.
    safe_name = Path((file.filename or "").replace("\\", "/")).name

    # Validate file type
    allowed_extensions = {".jpg", ".jpeg", ".png", ".pdf", ".tiff", ".tif", ".webp", ".bmp"}
    file_extension = Path(safe_name).suffix.lower()
    if file_extension not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            # sorted() keeps the message stable between requests
            detail=f"Invalid file type. Allowed: {', '.join(sorted(allowed_extensions))}"
        )

    # Bound before the try block so the cleanup handlers can always refer to it.
    file_path = UPLOAD_DIR / safe_name
    try:
        # Save uploaded file
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # Process with OCR (always uses Document AI + Gemini)
        # NOTE(review): this is a synchronous call inside an async handler; if
        # it blocks for long it stalls other requests -- consider a threadpool.
        print(f"Processing: {safe_name}")
        receipt_data = ocr_processor.process_receipt(str(file_path), save_json=False)
        if "error" in receipt_data:
            raise HTTPException(status_code=500, detail=receipt_data["error"])

        # Calculate processing cost
        costs = cost_tracker.calculate_receipt_cost(
            output_text=json.dumps(receipt_data)
        )
        # Build the response summary before touching the database, so a
        # malformed cost dict fails the request before a row is persisted.
        cost_summary = {
            "total": f"${costs['total']:.6f}",
            "document_ai": f"${costs['document_ai']:.6f}",
            "gemini": f"${costs['gemini_total']:.6f}",
            "tokens": costs['tokens']
        }

        # Save to database
        db_receipt = Receipt(
            filename=safe_name,
            merchant_name=receipt_data.get("merchant_name", ""),
            merchant_address=receipt_data.get("merchant_address", ""),
            date=receipt_data.get("date"),
            time=receipt_data.get("time"),
            total=receipt_data.get("total"),
            subtotal=receipt_data.get("subtotal"),
            tax=receipt_data.get("tax"),
            discount=receipt_data.get("discount", 0.0),
            currency=receipt_data.get("currency", ""),
            payment_method=receipt_data.get("payment_method", ""),
            receipt_number=receipt_data.get("receipt_number", ""),
            items=receipt_data.get("items", []),
            additional_info=receipt_data.get("additional_info", {}),
            raw_data=receipt_data,
            processing_cost=costs["total"]
        )
        db.add(db_receipt)
        db.commit()
        db.refresh(db_receipt)
    except HTTPException:
        # Already a well-formed API error (e.g. the OCR error above): clean up
        # the stored file and let it through instead of re-wrapping it.
        if file_path.exists():
            file_path.unlink()
        raise
    except Exception as e:
        # Undo any pending database work and clean up file on error
        db.rollback()
        if file_path.exists():
            file_path.unlink()
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}") from e

    return {
        "success": True,
        "message": "Receipt processed successfully",
        "receipt_id": db_receipt.id,
        "data": receipt_data,
        "processing_cost": cost_summary
    }
# Route restored from the endpoint listing returned by the index handler
# ("/receipts - GET"); without the decorator the handler is never registered.
@app.get("/receipts")
async def list_receipts(
    skip: int = 0,
    limit: int = 100,
    merchant: Optional[str] = None,
    date_from: Optional[str] = None,
    date_to: Optional[str] = None,
    db: Session = Depends(get_db)
):
    """
    List all receipts with optional filters

    Args:
        skip: Number of records to skip
        limit: Maximum number of records to return
        merchant: Filter by merchant name (substring match)
        date_from: Filter by date (YYYY-MM-DD)
        date_to: Filter by date (YYYY-MM-DD)
        db: Database session supplied by the ``get_db`` dependency

    Returns:
        Dict with the number of matching receipts (``total``), the paging
        values that were applied, and a summary entry per receipt, newest
        first.
    """
    query = db.query(Receipt)
    if merchant:
        # autoescape makes "%" and "_" in the user's text match literally
        # instead of acting as LIKE wildcards.
        query = query.filter(Receipt.merchant_name.contains(merchant, autoescape=True))
    # The date column is a string, so these are string comparisons.
    # NOTE(review): that only orders correctly if stored dates are in
    # YYYY-MM-DD form -- confirm what the OCR step emits.
    if date_from:
        query = query.filter(Receipt.date >= date_from)
    if date_to:
        query = query.filter(Receipt.date <= date_to)

    # Count before paging so "total" reflects every match, not just this page.
    total = query.count()
    receipts = query.order_by(Receipt.created_at.desc()).offset(skip).limit(limit).all()

    return {
        "total": total,
        "skip": skip,
        "limit": limit,
        "receipts": [
            {
                "id": r.id,
                "filename": r.filename,
                "merchant_name": r.merchant_name,
                "date": r.date,
                "time": r.time,
                "total": r.total,
                "currency": r.currency,
                "items_count": len(r.items) if r.items else 0,
                "created_at": r.created_at.isoformat()
            }
            for r in receipts
        ]
    }
# Route restored from the endpoint listing returned by the index handler
# ("/receipts/{id} - GET"); without the decorator the handler is never
# registered.
@app.get("/receipts/{receipt_id}")
async def get_receipt(receipt_id: int, db: Session = Depends(get_db)):
    """Get detailed information about a specific receipt

    Args:
        receipt_id: Primary key of the receipt to load
        db: Database session supplied by the ``get_db`` dependency

    Returns:
        Dict with the stored receipt fields (``raw_data`` is not included);
        timestamps are ISO-8601 strings.

    Raises:
        HTTPException: 404 when no receipt has this id.
    """
    receipt = db.query(Receipt).filter(Receipt.id == receipt_id).first()
    if not receipt:
        raise HTTPException(status_code=404, detail="Receipt not found")

    return {
        "id": receipt.id,
        "filename": receipt.filename,
        "merchant_name": receipt.merchant_name,
        "merchant_address": receipt.merchant_address,
        "date": receipt.date,
        "time": receipt.time,
        "total": receipt.total,
        "subtotal": receipt.subtotal,
        "tax": receipt.tax,
        "discount": receipt.discount,
        "currency": receipt.currency,
        "payment_method": receipt.payment_method,
        "receipt_number": receipt.receipt_number,
        "items": receipt.items,
        "additional_info": receipt.additional_info,
        "processing_cost": receipt.processing_cost,
        "created_at": receipt.created_at.isoformat(),
        "updated_at": receipt.updated_at.isoformat()
    }
# Route restored from the endpoint listing returned by the index handler
# ("/receipts/{id} - DELETE"); without the decorator the handler is never
# registered.
@app.delete("/receipts/{receipt_id}")
async def delete_receipt(receipt_id: int, db: Session = Depends(get_db)):
    """Delete a receipt from the database, then its uploaded file

    Args:
        receipt_id: Primary key of the receipt to delete
        db: Database session supplied by the ``get_db`` dependency

    Returns:
        Dict with a success flag and a confirmation message.

    Raises:
        HTTPException: 404 when no receipt has this id.
    """
    receipt = db.query(Receipt).filter(Receipt.id == receipt_id).first()
    if not receipt:
        raise HTTPException(status_code=404, detail="Receipt not found")

    # Remember the name before the row goes away.
    stored_name = receipt.filename

    # Remove the row first: if the commit fails, the uploaded file is still
    # on disk and nothing has been lost.
    db.delete(receipt)
    db.commit()

    # Delete associated file
    if stored_name:
        # Uploads with the same file name share one file on disk, so keep it
        # while any other receipt still refers to that name.
        still_referenced = (
            db.query(Receipt).filter(Receipt.filename == stored_name).first() is not None
        )
        # Only the last path component is used, so a stored name can never
        # point outside UPLOAD_DIR; is_file() avoids calling unlink() on a
        # directory.
        file_path = UPLOAD_DIR / Path(stored_name).name
        if not still_referenced and file_path.is_file():
            try:
                file_path.unlink()
            except OSError as e:
                # The row is already gone; a leftover file is only worth a warning.
                print(f"Warning: Could not delete uploaded file {file_path}: {e}")

    return {"success": True, "message": f"Receipt {receipt_id} deleted"}
# Route restored from the endpoint listing returned by the index handler
# ("/stats - GET"); without the decorator the handler is never registered.
@app.get("/stats")
async def get_statistics(db: Session = Depends(get_db)):
    """Get statistics about receipts

    Args:
        db: Database session supplied by the ``get_db`` dependency

    Returns:
        Dict with the receipt count, the sum of all totals (rounded to two
        decimals), the ten merchants with the most receipts, and per-currency
        counts and sums. Receipts without a merchant name or currency are left
        out of the respective breakdown.
    """
    from sqlalchemy import func

    total_receipts = db.query(func.count(Receipt.id)).scalar()
    # SUM over zero rows (or only NULLs) yields None; report 0 instead.
    total_amount = db.query(func.sum(Receipt.total)).scalar() or 0

    # Top merchants
    # Unnamed receipts are excluded in SQL, before the LIMIT, so they cannot
    # take up one of the ten slots. The "!=" comparison also drops NULL names.
    top_merchants = db.query(
        Receipt.merchant_name,
        func.count(Receipt.id).label("count"),
        func.sum(Receipt.total).label("total")
    ).filter(
        Receipt.merchant_name != ""
    ).group_by(Receipt.merchant_name).order_by(func.count(Receipt.id).desc()).limit(10).all()

    # By currency
    by_currency = db.query(
        Receipt.currency,
        func.count(Receipt.id).label("count"),
        func.sum(Receipt.total).label("total")
    ).group_by(Receipt.currency).all()

    return {
        "total_receipts": total_receipts,
        "total_amount": round(total_amount, 2),
        "top_merchants": [
            {
                "merchant": m.merchant_name,
                "receipt_count": m.count,
                "total_spent": round(m.total or 0, 2)
            }
            for m in top_merchants
        ],
        "by_currency": [
            {
                "currency": c.currency,
                "receipt_count": c.count,
                "total": round(c.total or 0, 2)
            }
            for c in by_currency if c.currency
        ]
    }
# NOTE(review): the route decorator was missing, so this handler was never
# reachable. The path "/search" is an assumption (the index handler's endpoint
# listing does not mention it) -- confirm against the frontend.
@app.get("/search")
async def search_receipts(
    q: str,
    db: Session = Depends(get_db)
):
    """
    Search receipts by merchant name, merchant address, or receipt number

    Args:
        q: Search query (substring match)
        db: Database session supplied by the ``get_db`` dependency

    Returns:
        Dict with the query, the number of hits, and a summary entry per
        matching receipt, newest first, limited to 50 results.
    """
    # autoescape makes "%" and "_" in the user's text match literally instead
    # of acting as LIKE wildcards.
    receipts = db.query(Receipt).filter(
        (Receipt.merchant_name.contains(q, autoescape=True)) |
        (Receipt.merchant_address.contains(q, autoescape=True)) |
        (Receipt.receipt_number.contains(q, autoescape=True))
    ).order_by(Receipt.created_at.desc()).limit(50).all()

    return {
        "query": q,
        "count": len(receipts),
        "receipts": [
            {
                "id": r.id,
                "filename": r.filename,
                "merchant_name": r.merchant_name,
                "date": r.date,
                "total": r.total,
                "currency": r.currency,
                "created_at": r.created_at.isoformat()
            }
            for r in receipts
        ]
    }
if __name__ == "__main__":
    # Executed directly as a script: serve the app with uvicorn, listening on
    # every interface at port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")