# Scraped page header (not code): wolf1997 - "Update app.py" - commit d0c1517 (verified)
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import base64
import uvicorn
import os
from dotenv import load_dotenv
from typing import List
# Import your LangGraph receipt scanner agent
from receipt_gen_agent import receipt_agent
# Value of the SPACE_URL environment variable; an unset variable yields "".
HF_SPACE_URL = os.environ.get("SPACE_URL", "")

# The application is mounted under SPACE_URL when it is set, otherwise "/".
app = FastAPI(
    title="Receipt Scanner API",
    root_path=HF_SPACE_URL or "/",
)
# Configure CORS so the browser front end (the React app) can call this API.
#
# Allowed origins may be supplied as a comma-separated list in the
# CORS_ALLOW_ORIGINS environment variable, for example
# "https://my-app.example.com,http://localhost:3000". The variable is read
# from the process environment when this module is imported. When it is unset
# or empty the previous behaviour is kept and every origin is allowed.
#
# NOTE(review): the wildcard default is combined with allow_credentials=True.
# The FastAPI CORS documentation advises listing explicit origins when
# credentials are allowed, so set CORS_ALLOW_ORIGINS in production.
_cors_allow_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins or ["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Data model for the response
class ReceiptItem(BaseModel):
    # One purchased line on a receipt.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays exactly as it was.)
    name: str
    price: float
    quantity: int = 1  # a missing quantity is treated as a single unit
class ReceiptResponse(BaseModel):
    # Shape of a scanned receipt: the same four keys that
    # format_receipt_data() produces.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays exactly as it was.)
    store: str
    date: str  # kept as a plain string; no date parsing is applied here
    total: float
    items: list[ReceiptItem]
class ImageData(BaseModel):
    # Request body carrying a list of image strings.
    # NOTE(review): presumably base64-encoded images - confirm against callers.
    images: list[str]
# Initialize the receipt agent.
#
# Load .env first: this module's own load_dotenv() call sits further down the
# file, after this constructor has already run, so any setting the agent reads
# from the environment while it is being constructed would be missing when it
# is defined only in a .env file. Calling load_dotenv() more than once is
# harmless; by default it does not override variables that are already set.
load_dotenv()
receipt_scanner = receipt_agent()
# Format the response from the LangGraph agent to match our API response format
def format_receipt_data(agent_response):
    """Convert the agent's receipt mapping into the API response shape.

    The agent output is read through ``.get`` using the keys ``loc_name``,
    ``date``, ``total`` and ``items``; each item is read using the keys
    ``name``, ``price`` and ``quantity``.

    A key that is absent *or* explicitly ``None`` is replaced by its default,
    so the result always carries a usable value for every field. All other
    values (including ``0`` and ``""``) are passed through untouched.

    Args:
        agent_response: Mapping returned by the receipt agent.

    Returns:
        A dict with the keys ``store``, ``date``, ``total`` and ``items``,
        where ``items`` is a list of dicts with the keys ``name``, ``price``
        and ``quantity``.

    Raises:
        AttributeError: If ``agent_response`` (or one of its items) has no
            ``.get`` method, e.g. when the agent returned ``None``.
    """

    def _value_or_default(mapping, key, default):
        # ``mapping.get(key, default)`` alone would let an explicit None
        # through, which is what this helper exists to prevent.
        value = mapping.get(key)
        return default if value is None else value

    formatted_items = [
        {
            "name": _value_or_default(item, 'name', 'Unknown Item'),
            "price": _value_or_default(item, 'price', 0.0),
            "quantity": _value_or_default(item, 'quantity', 1),
        }
        for item in _value_or_default(agent_response, 'items', [])
    ]

    return {
        "store": _value_or_default(agent_response, 'loc_name', ''),
        "date": _value_or_default(agent_response, 'date', ''),
        "total": _value_or_default(agent_response, 'total', 0),
        "items": formatted_items,
    }
# Function to scan receipt using the LangGraph agent
def scan_receipt(images):
    """
    Run the receipt scanner agent over a batch of images.

    The images are passed to ``receipt_scanner.receipt_gen`` and the agent's
    answer is reshaped with ``format_receipt_data``. Any exception raised
    along the way is printed and swallowed; a fixed placeholder receipt is
    returned instead, so callers always get a receipt-shaped dictionary.
    """
    try:
        raw_receipt = receipt_scanner.receipt_gen(images)
        return format_receipt_data(raw_receipt)
    except Exception as exc:
        print(f"Error in receipt scanning: {str(exc)}")

    # Only reached after a failure: hand back the placeholder receipt.
    placeholder_item = {"name": "Unable to scan receipt", "price": 0, "quantity": 1}
    return {
        "store": "Receipt Scanner",
        "date": "2025-03-24",
        "total": 0,
        "items": [placeholder_item],
    }
# POST /api/scan: accept uploaded receipt images, base64-encode each one and
# return whatever scan_receipt() produces for the batch. Any exception is
# logged with its traceback and reported to the client as an HTTP 500.
# (Documented with comments rather than a docstring so the generated OpenAPI
# description of the route stays exactly as it was.)
#
# NOTE(review): scan_receipt() is a plain synchronous function called directly
# from this coroutine, so the event loop waits for it before serving other
# requests. Consider a worker thread if the agent is slow - but confirm the
# agent is safe to call concurrently first.
@app.post("/api/scan")
async def process_receipt(
    images: list[UploadFile] = File(...)
):
    try:
        encoded_images = []
        for upload in images:
            print(f"Processing file: {upload.filename}")
            payload = await upload.read()
            print(f"File size: {len(payload)} bytes")
            # b64encode yields ASCII bytes; turn them into a str for the agent.
            encoded_images.append(base64.b64encode(payload).decode('utf-8'))

        print("Sending image to receipt scanner...")
        receipt = scan_receipt(encoded_images)
        print("Receipt scanning completed successfully")
        return receipt
    except Exception as exc:
        import traceback

        print(f"API error: {str(exc)}")
        print(f"Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(exc)}")
# GET /: landing route that reports the service is up and lists its endpoints.
@app.get("/")
async def root():
    available_endpoints = {"scan": "/api/scan", "health": "/health"}
    return {
        "message": "Receipt Scanner API is running",
        "endpoints": available_endpoints,
    }
# Health check endpoint
@app.get("/health")
async def health_check():
    # Constant payload; nothing else is inspected before answering.
    return dict(status="ok")
# Environment setup: load variables via python-dotenv, then warn (without
# failing) when google_api_key is missing or empty.
load_dotenv()

if not os.environ.get('google_api_key'):
    print(
        "WARNING: google_api_key not found in environment variables",
        "Please ensure you have set up your .env file with the Google API key",
        sep="\n",
    )
if __name__ == "__main__":
    # Listen on the port named by the PORT environment variable, or 8000.
    port = int(os.getenv("PORT", 8000))

    # Start-up banner, printed before uvicorn takes over.
    print(
        f"Starting FastAPI server on port {port}",
        "Initializing receipt scanner agent...",
        "Server ready to process receipt images",
        sep="\n",
    )

    # Bind on all interfaces, with uvicorn's auto-reload switched on.
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=True)