# ocr_myfinance / app_clean.py
# Author: Mariem-Daha
# Commit: 8bcd133 (verified) - "Upload 7 files"
"""
FastAPI Backend for OCR Receipt Processing
Uses a SQLite database and is ready for deployment on Hugging Face Spaces.
"""
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime, JSON, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from datetime import datetime
from typing import List, Optional
import os
import shutil
import json
from pathlib import Path
from dotenv import load_dotenv
# Load environment variables from a .env file (if one is found) into os.environ.
# This runs before the project-local imports that follow, presumably so those
# modules can see the configuration at import time.
load_dotenv()
# Import our OCR processor
from ocr_with_gemini import EnhancedReceiptOCR
from cost_tracker import CostTracker
# Initialize cost tracker: one module-level instance, used by the upload
# endpoint to compute the processing cost of each receipt.
cost_tracker = CostTracker()
# Database setup. The presence of /app is treated as "running on Hugging Face",
# where the SQLite file must live in /tmp to be writable; otherwise the file is
# created in the current working directory.
SQLALCHEMY_DATABASE_URL = (
    "sqlite:////tmp/receipts.db"
    if os.path.exists("/app")
    else "sqlite:///./receipts.db"
)

# check_same_thread=False lets SQLite connections be used from a thread other
# than the one that opened them (requests may be served from worker threads).
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
)
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
Base = declarative_base()
# Upload and static directories. The presence of /app is treated as "running on
# Hugging Face", where uploads go to /tmp for write access; the static
# directory is expected to ship with the repository in both environments.
if os.path.exists("/app"):
    UPLOAD_DIR = Path("/tmp/uploads")
    STATIC_DIR = Path("static")  # static should already exist in repo
else:
    UPLOAD_DIR = Path("uploads")
    STATIC_DIR = Path("static")

# Directory creation is best-effort: a failure only prints a warning so the app
# can still start. OSError is caught rather than just PermissionError because a
# read-only filesystem raises OSError(EROFS), and a path that exists as a
# regular file raises FileExistsError -- neither is a PermissionError.
try:
    UPLOAD_DIR.mkdir(exist_ok=True, parents=True)
except OSError:
    print(f"Warning: Could not create upload directory at {UPLOAD_DIR}")
try:
    if not STATIC_DIR.exists():
        STATIC_DIR.mkdir(exist_ok=True, parents=True)
except OSError:
    print(f"Warning: Static directory doesn't exist at {STATIC_DIR}")
# Database Models
class Receipt(Base):
    """ORM model for one processed receipt, stored in the ``receipts`` table.

    Scalar columns hold the values extracted by the OCR pipeline; ``items``,
    ``additional_info`` and ``raw_data`` are stored in JSON columns.
    """

    __tablename__ = "receipts"
    id = Column(Integer, primary_key=True, index=True)
    filename = Column(String, index=True)  # name of the uploaded file
    merchant_name = Column(String, index=True)
    merchant_address = Column(Text)
    # date/time are plain strings, not DATE/TIME columns. The /receipts date
    # filters compare them as strings, which only orders chronologically for
    # zero-padded ISO values -- TODO confirm the OCR output format.
    date = Column(String)
    time = Column(String)
    total = Column(Float)
    subtotal = Column(Float)
    tax = Column(Float)
    discount = Column(Float)
    currency = Column(String)
    payment_method = Column(String)
    receipt_number = Column(String)
    items = Column(JSON)  # Store items as JSON
    additional_info = Column(JSON)  # Store additional info as JSON
    raw_data = Column(JSON)  # Store complete raw data
    processing_cost = Column(Float, default=0.0)  # Cost to process this receipt
    # Naive UTC timestamps (datetime.utcnow); updated_at is refreshed by the
    # ORM whenever the row is updated.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)


# Create tables
# create_all only creates tables that do not exist yet; it never alters an
# existing table, so schema changes to Receipt need a migration or a fresh DB.
Base.metadata.create_all(bind=engine)
# Initialize FastAPI
app = FastAPI(
    title="Receipt OCR API",
    description="Advanced OCR API for receipts and bills using Document AI + Gemini AI",
    version="1.0.0",
)

# CORS middleware: every origin, method and header is accepted.
# NOTE(review): a wildcard origin together with allow_credentials=True is very
# permissive; restrict allow_origins if cookies or auth are ever added.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files from STATIC_DIR -- the same directory the "/" route reads
# index.html from -- rather than repeating the path as a hard-coded string.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
# Initialize the OCR processor from environment variables. Failure is
# non-fatal: ocr_processor is left as None, which makes /upload respond with
# HTTP 500 "OCR processor not initialized".
try:
    ocr_processor = EnhancedReceiptOCR(
        project_id=os.getenv("PROJECT_ID"),
        location=os.getenv("LOCATION"),
        processor_id=os.getenv("PROCESSOR_ID"),
        gemini_api_key=os.getenv("GEMINI_API_KEY"),
    )
except Exception as e:  # startup boundary: any init error just disables OCR
    print(f"Warning: OCR processor initialization failed: {e}")
    ocr_processor = None
else:
    # Kept out of the try body so a problem emitting the success message can
    # never be mistaken for an init failure and discard a working processor.
    print("✓ OCR processor initialized successfully")
# Dependency
def get_db():
    """Yield a fresh SQLAlchemy session for one request and close it afterwards.

    Intended for ``Depends(get_db)``: the session is handed to the endpoint at
    the ``yield``, and the ``finally`` clause closes it even if the endpoint
    raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# Routes
@app.get("/")
async def root():
    """Serve the web UI, or a JSON service description when it is missing.

    Returns ``STATIC_DIR / "index.html"`` if that file exists; otherwise a
    health/status payload listing the API endpoints.
    """
    html_file = STATIC_DIR / "index.html"
    if html_file.exists():
        return FileResponse(html_file)
    return {
        "status": "healthy",
        "service": "Receipt OCR API",
        "version": "1.0.0",
        # Keep in sync with the routes defined in this module.
        "endpoints": {
            "upload": "/upload - POST - Upload receipt image",
            "receipts": "/receipts - GET - List all receipts",
            "receipt": "/receipts/{id} - GET - Get specific receipt",
            "delete": "/receipts/{id} - DELETE - Delete receipt",
            "stats": "/stats - GET - Get statistics",
            "search": "/search?q={query} - GET - Search receipts",
        },
    }
def _discard_upload(path: Path) -> None:
    """Best-effort removal of a saved upload after a failed processing attempt."""
    try:
        if path.exists():
            path.unlink()
    except OSError as err:
        # Cleanup must never mask the error that triggered it.
        print(f"Warning: Could not remove upload {path}: {err}")


@app.post("/upload")
async def upload_receipt(
    file: UploadFile = File(...),
    db: Session = Depends(get_db)
):
    """
    Upload and process a receipt image

    The file is saved under UPLOAD_DIR, run through the OCR processor, costed
    with ``cost_tracker`` and stored as a ``Receipt`` row. On any failure the
    saved file is removed again.

    Args:
        file: Receipt image file (JPG, PNG, PDF, etc.)
    Returns:
        Processed receipt data with database ID and the formatted processing cost
    Raises:
        HTTPException: 400 if the filename is missing or has an unsupported
            extension; 500 if the OCR processor is unavailable, reports an
            error, or any other processing/database step fails.
    """
    if not ocr_processor:
        raise HTTPException(status_code=500, detail="OCR processor not initialized")

    # Validate file type. Only the final path component of the client-supplied
    # (untrusted) filename is used, so names such as "../../x.jpg" or
    # "/etc/x.jpg" cannot write outside UPLOAD_DIR. A missing filename yields
    # an empty name and is rejected by the extension check below.
    safe_name = Path(file.filename or "").name
    allowed_extensions = {".jpg", ".jpeg", ".png", ".pdf", ".tiff", ".tif", ".webp", ".bmp"}
    file_extension = Path(safe_name).suffix.lower()
    if file_extension not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type. Allowed: {', '.join(allowed_extensions)}"
        )

    # Bound before the try so the cleanup handlers can always reference it.
    file_path = UPLOAD_DIR / safe_name
    try:
        # Save uploaded file
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # Process with OCR (always uses Document AI + Gemini).
        # NOTE(review): process_receipt is a synchronous call inside an async
        # endpoint, so it blocks the event loop for its whole duration.
        print(f"Processing: {file.filename}")
        receipt_data = ocr_processor.process_receipt(str(file_path), save_json=False)
        if "error" in receipt_data:
            raise HTTPException(status_code=500, detail=receipt_data["error"])

        # Calculate and format the processing cost before touching the
        # database, so nothing that can fail runs after the row is committed.
        costs = cost_tracker.calculate_receipt_cost(
            output_text=json.dumps(receipt_data)
        )
        cost_summary = {
            "total": f"${costs['total']:.6f}",
            "document_ai": f"${costs['document_ai']:.6f}",
            "gemini": f"${costs['gemini_total']:.6f}",
            "tokens": costs['tokens']
        }

        # Save to database
        db_receipt = Receipt(
            filename=safe_name,
            merchant_name=receipt_data.get("merchant_name", ""),
            merchant_address=receipt_data.get("merchant_address", ""),
            date=receipt_data.get("date"),
            time=receipt_data.get("time"),
            total=receipt_data.get("total"),
            subtotal=receipt_data.get("subtotal"),
            tax=receipt_data.get("tax"),
            discount=receipt_data.get("discount", 0.0),
            currency=receipt_data.get("currency", ""),
            payment_method=receipt_data.get("payment_method", ""),
            receipt_number=receipt_data.get("receipt_number", ""),
            items=receipt_data.get("items", []),
            additional_info=receipt_data.get("additional_info", {}),
            raw_data=receipt_data,
            processing_cost=costs["total"]
        )
        db.add(db_receipt)
        db.commit()
        db.refresh(db_receipt)
    except HTTPException:
        # Deliberate HTTP errors (e.g. an OCR-reported error) keep their own
        # status and detail instead of being re-wrapped as "Processing failed".
        _discard_upload(file_path)
        raise
    except Exception as e:
        db.rollback()
        _discard_upload(file_path)
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}") from e

    return {
        "success": True,
        "message": "Receipt processed successfully",
        "receipt_id": db_receipt.id,
        "data": receipt_data,
        "processing_cost": cost_summary,
    }
@app.get("/receipts")
async def list_receipts(
    skip: int = 0,
    limit: int = 100,
    merchant: Optional[str] = None,
    date_from: Optional[str] = None,
    date_to: Optional[str] = None,
    db: Session = Depends(get_db)
):
    """
    List receipts, newest first, with optional filters and pagination.

    Args:
        skip: Number of records to skip
        limit: Maximum number of records to return
        merchant: Substring that the merchant name must contain
        date_from: Inclusive lower bound on the stored date string (YYYY-MM-DD)
        date_to: Inclusive upper bound on the stored date string (YYYY-MM-DD)

    Returns:
        ``total`` (number of matching rows, ignoring skip/limit), the paging
        parameters echoed back, and a short summary of each receipt on the page.
    """
    # Collect the active conditions; filter() ANDs them together.
    criteria = []
    if merchant:
        criteria.append(Receipt.merchant_name.contains(merchant))
    if date_from:
        criteria.append(Receipt.date >= date_from)
    if date_to:
        criteria.append(Receipt.date <= date_to)

    filtered = db.query(Receipt).filter(*criteria)
    matching = filtered.count()
    page = (
        filtered.order_by(Receipt.created_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )

    summaries = [
        {
            "id": rec.id,
            "filename": rec.filename,
            "merchant_name": rec.merchant_name,
            "date": rec.date,
            "time": rec.time,
            "total": rec.total,
            "currency": rec.currency,
            "items_count": len(rec.items or []),
            "created_at": rec.created_at.isoformat(),
        }
        for rec in page
    ]
    return {"total": matching, "skip": skip, "limit": limit, "receipts": summaries}
@app.get("/receipts/{receipt_id}")
async def get_receipt(receipt_id: int, db: Session = Depends(get_db)):
    """Return every stored field of one receipt except ``raw_data``.

    Raises:
        HTTPException: 404 if no receipt has the given id.
    """
    found = db.query(Receipt).filter(Receipt.id == receipt_id).first()
    if found is None:
        raise HTTPException(status_code=404, detail="Receipt not found")

    # Plain columns are copied as-is; the timestamps are rendered as ISO-8601.
    payload = {
        column: getattr(found, column)
        for column in (
            "id",
            "filename",
            "merchant_name",
            "merchant_address",
            "date",
            "time",
            "total",
            "subtotal",
            "tax",
            "discount",
            "currency",
            "payment_method",
            "receipt_number",
            "items",
            "additional_info",
            "processing_cost",
        )
    }
    payload["created_at"] = found.created_at.isoformat()
    payload["updated_at"] = found.updated_at.isoformat()
    return payload
@app.delete("/receipts/{receipt_id}")
async def delete_receipt(receipt_id: int, db: Session = Depends(get_db)):
    """Delete a receipt from the database and remove its uploaded file.

    The file is removed only after the row deletion is committed, only when no
    other receipt references the same filename, and only when the stored name
    is a plain file name inside UPLOAD_DIR.

    Raises:
        HTTPException: 404 if no receipt has the given id.
    """
    receipt = db.query(Receipt).filter(Receipt.id == receipt_id).first()
    if not receipt:
        raise HTTPException(status_code=404, detail="Receipt not found")

    stored_name = receipt.filename
    # Uploads are saved as UPLOAD_DIR / <filename>, so another receipt with the
    # same filename points at the very same file on disk.
    shared = (
        db.query(Receipt)
        .filter(Receipt.filename == stored_name, Receipt.id != receipt_id)
        .first()
        is not None
    )

    # Commit first so a failed commit cannot leave a receipt whose file is gone.
    db.delete(receipt)
    db.commit()

    # A stored name with directory components could resolve outside
    # UPLOAD_DIR, so it is never used as a deletion target; is_file() also
    # rules out names such as "" or ".." that resolve to directories.
    if stored_name and not shared and Path(stored_name).name == stored_name:
        file_path = UPLOAD_DIR / stored_name
        if file_path.is_file():
            try:
                file_path.unlink()
            except OSError as err:
                # The row is already deleted; report success but surface the leak.
                print(f"Warning: Could not remove upload {file_path}: {err}")

    return {"success": True, "message": f"Receipt {receipt_id} deleted"}
@app.get("/stats")
async def get_statistics(db: Session = Depends(get_db)):
    """Get aggregate statistics about stored receipts.

    Returns the receipt count, the sum of all totals (rounded to 2 decimals),
    up to ten named merchants ordered by receipt count, and per-currency
    counts and totals. Missing sums are reported as 0.
    """
    from sqlalchemy import func

    total_receipts = db.query(func.count(Receipt.id)).scalar()
    total_amount = db.query(func.sum(Receipt.total)).scalar() or 0

    # Top merchants. Unnamed merchants are excluded in SQL, before LIMIT, so
    # they cannot occupy one of the ten slots (`!= ""` also drops NULLs, since
    # NULL <> '' is not true). Labels avoid the name "count", which collides
    # with the tuple-style Row.count() method in some SQLAlchemy versions.
    top_merchants = (
        db.query(
            Receipt.merchant_name,
            func.count(Receipt.id).label("receipt_count"),
            func.sum(Receipt.total).label("amount"),
        )
        .filter(Receipt.merchant_name != "")
        .group_by(Receipt.merchant_name)
        .order_by(func.count(Receipt.id).desc())
        .limit(10)
        .all()
    )

    # By currency
    by_currency = (
        db.query(
            Receipt.currency,
            func.count(Receipt.id).label("receipt_count"),
            func.sum(Receipt.total).label("amount"),
        )
        .group_by(Receipt.currency)
        .all()
    )

    return {
        "total_receipts": total_receipts,
        "total_amount": round(total_amount, 2),
        "top_merchants": [
            {
                "merchant": m.merchant_name,
                "receipt_count": m.receipt_count,
                "total_spent": round(m.amount or 0, 2),
            }
            for m in top_merchants
        ],
        "by_currency": [
            {
                "currency": c.currency,
                "receipt_count": c.receipt_count,
                "total": round(c.amount or 0, 2),
            }
            for c in by_currency
            if c.currency
        ],
    }
@app.get("/search")
async def search_receipts(
    q: str,
    db: Session = Depends(get_db)
):
    """
    Search receipts by merchant name, merchant address, or receipt number

    Each of those columns is matched with a substring (SQL LIKE) test; line
    items are not searched. ``%`` and ``_`` in the query act as LIKE
    wildcards. At most the 50 most recently created matches are returned.

    Args:
        q: Search query
    """
    matches_query = (
        Receipt.merchant_name.contains(q)
        | Receipt.merchant_address.contains(q)
        | Receipt.receipt_number.contains(q)
    )
    receipts = (
        db.query(Receipt)
        .filter(matches_query)
        .order_by(Receipt.created_at.desc())
        .limit(50)
        .all()
    )
    return {
        "query": q,
        "count": len(receipts),
        "receipts": [
            {
                "id": r.id,
                "filename": r.filename,
                "merchant_name": r.merchant_name,
                "date": r.date,
                "total": r.total,
                "currency": r.currency,
                "created_at": r.created_at.isoformat(),
            }
            for r in receipts
        ],
    }
if __name__ == "__main__":
    # Direct-execution entry point. uvicorn is imported here, so it is only
    # required when the module is run as a script; 0.0.0.0 binds every interface.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")