Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, UploadFile, File, Form | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import base64 | |
| import uvicorn | |
| import os | |
| from dotenv import load_dotenv | |
| from typing import List | |
| # Import your LangGraph receipt scanner agent | |
| from receipt_gen_agent import receipt_agent | |
# Load .env before any module-level code below reads environment variables
# (SPACE_URL right here, and anything the receipt agent reads when it is built
# further down). load_dotenv() does not override variables that are already
# set, so calling it again later in the file is harmless.
load_dotenv()

# Optional prefix under which the deployed Space serves this app; "" when unset.
# NOTE(review): FastAPI's root_path expects a path prefix such as "/proxy";
# confirm SPACE_URL holds a path rather than a full URL.
HF_SPACE_URL = os.environ.get("SPACE_URL", "")

app = FastAPI(
    title="Receipt Scanner API",
    # FastAPI's default root_path is "" (served at the origin root). The former
    # fallback of "/" is not a real prefix and only invites edge cases in how
    # the framework strips root_path from incoming request paths.
    root_path=HF_SPACE_URL,
)

# Configure CORS so the React front end can call this API from a browser.
app.add_middleware(
    CORSMiddleware,
    # "*" lets any web origin call this API; in production, list only the
    # React app's origin here.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic models describing the receipt JSON shape.
class ReceiptItem(BaseModel):
    """One line item on a receipt."""

    name: str
    price: float
    quantity: int = 1  # defaults to 1 when not supplied


class ReceiptResponse(BaseModel):
    """A scanned receipt: store name, date string, total and its line items."""

    store: str
    date: str
    total: float
    items: list[ReceiptItem]


class ImageData(BaseModel):
    """A batch of images supplied as strings.

    NOTE(review): presumably base64-encoded images; this model is not
    referenced by the endpoints visible in this file — confirm it is used.
    """

    images: list[str]


# Single agent instance, created when the module is imported and reused by
# every scan_receipt() call. If construction fails, importing the module fails.
# NOTE(review): whether receipt_agent() reads environment variables (e.g. the
# Google API key) at construction time is not visible from this file.
receipt_scanner = receipt_agent()
# Format the response from the LangGraph agent to match our API response format
def _get_or_default(mapping, key, default):
    """Return ``mapping[key]``, or *default* when the key is missing or its value is None."""
    value = mapping.get(key)
    return default if value is None else value


def format_receipt_data(agent_response):
    """
    Format the receipt data from the LangGraph agent to match our API response format.

    Args:
        agent_response: mapping returned by the agent. Keys read: 'loc_name'
            (store name), 'date', 'total' and 'items', an iterable of mappings
            with 'name', 'price' and 'quantity'.

    Returns:
        dict with keys 'store', 'date', 'total' and 'items', matching the
        ReceiptResponse / ReceiptItem field layout.

    A key that is missing *or* present with value None falls back to its
    default, so the result never carries None where the models declare
    str/float/int/list. Falsy but real values (e.g. quantity 0) are kept.
    """
    items = _get_or_default(agent_response, 'items', [])
    return {
        "store": _get_or_default(agent_response, 'loc_name', ''),
        "date": _get_or_default(agent_response, 'date', ''),
        "total": _get_or_default(agent_response, 'total', 0),
        "items": [
            {
                "name": _get_or_default(item, 'name', 'Unknown Item'),
                "price": _get_or_default(item, 'price', 0.0),
                "quantity": _get_or_default(item, 'quantity', 1),
            }
            for item in items
        ],
    }
# Function to scan receipt using the LangGraph agent
def scan_receipt(images):
    """
    Process multiple images using the LangGraph receipt scanner agent.

    Args:
        images: list of images handed unchanged to
            ``receipt_scanner.receipt_gen``; the upload endpoint in this file
            passes base64-encoded strings.

    Returns:
        dict with keys 'store', 'date', 'total' and 'items' as produced by
        ``format_receipt_data``. If anything fails, a fixed placeholder
        receipt is returned instead of raising.
    """
    try:
        agent_response = receipt_scanner.receipt_gen(images)
        return format_receipt_data(agent_response)
    except Exception as e:
        # Deliberate best-effort: log the failure with its full traceback (the
        # message alone is rarely enough to debug agent errors), then hand back
        # a well-formed placeholder so the client still gets a receipt shape.
        import traceback

        print(f"Error in receipt scanning: {str(e)}")
        print(f"Traceback: {traceback.format_exc()}")
        return {
            "store": "Receipt Scanner",
            "date": "2025-03-24",
            "total": 0,
            "items": [
                {"name": "Unable to scan receipt", "price": 0, "quantity": 1}
            ]
        }
@app.post("/api/scan")
async def process_receipt(
    images: list[UploadFile] = File(...)
):
    """Read each uploaded image, base64-encode it, and scan the whole batch.

    Args:
        images: one or more uploaded files, sent as the multipart form
            field ``images``.

    Returns:
        The dict produced by ``scan_receipt`` (keys store/date/total/items).
        ``scan_receipt`` returns placeholder data rather than raising when
        the agent fails.

    Raises:
        HTTPException: status 500 if anything in this handler raises, e.g.
            reading an upload.
    """
    import asyncio
    import traceback

    try:
        base64_images = []
        for upload in images:
            print(f"Processing file: {upload.filename}")
            # NOTE(review): each upload is read fully into memory with no size
            # limit; consider capping upload size for untrusted clients.
            contents = await upload.read()
            print(f"File size: {len(contents)} bytes")
            # Encode the raw bytes; only the base64 output (pure ASCII) is
            # decoded to str, never the image bytes themselves.
            base64_images.append(base64.b64encode(contents).decode('utf-8'))

        print("Sending image to receipt scanner...")
        # scan_receipt is synchronous; run it in a worker thread so a slow
        # agent call does not stall the event loop for every other request.
        result = await asyncio.to_thread(scan_receipt, base64_images)
        print("Receipt scanning completed successfully")
        return result
    except Exception as e:
        print(f"API error: {str(e)}")
        print(f"Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
@app.get("/")
async def root():
    """Return a status message and the paths of the scan and health endpoints.

    The handler is registered here so the index is actually served; without a
    route decorator FastAPI never exposes it.
    """
    return {
        "message": "Receipt Scanner API is running",
        "endpoints": {
            "scan": "/api/scan",
            "health": "/health"
        }
    }
# Health check endpoint (the path advertised by the root index).
@app.get("/health")
async def health_check():
    """Return a static ``{"status": "ok"}`` payload.

    It does not probe the receipt agent or the API key; a 200 response only
    means the web server is up and routing requests.
    """
    return {"status": "ok"}
# Environment setup.
# Variables loaded from .env here are only visible to code that reads the
# environment after this point; module-level code above that reads os.environ
# needs load_dotenv() to have been called before it runs.
load_dotenv()
if not os.environ.get('google_api_key'):
    print(
        "WARNING: google_api_key not found in environment variables",
        "Please ensure you have set up your .env file with the Google API key",
        sep="\n",
    )
if __name__ == "__main__":
    # Listen on $PORT when it is set, otherwise on 8000.
    port = int(os.environ.get("PORT", 8000))
    # NOTE(review): the receipt agent was already constructed at import time,
    # before this message is printed.
    print(
        f"Starting FastAPI server on port {port}",
        "Initializing receipt scanner agent...",
        "Server ready to process receipt images",
        sep="\n",
    )
    # Auto-reload needs an import string rather than the app object; "app:app"
    # assumes this file is importable as the module `app`.
    uvicorn.run("app:app", reload=True, port=port, host="0.0.0.0")