Spaces:
Runtime error
Runtime error
Franko Fišter commited on
Commit ·
c1c1a13
1
Parent(s): bcd0a9b
Working Supabase connection and receipt request processing
Browse files- api/receipt_routes.py +41 -3
- config/settings.py +12 -0
- db/receipt_repository.py +82 -0
- db/supabase_client.py +21 -0
- requirements.txt +3 -1
- utils/rate_limiter.py +25 -0
api/receipt_routes.py
CHANGED
|
@@ -1,28 +1,66 @@
|
|
| 1 |
-
from fastapi import APIRouter, File, UploadFile, HTTPException
|
| 2 |
from receipt_processor.google_ocr import GoogleVisionOCR
|
| 3 |
from receipt_processor.parsers.parser_selector import ParserSelector
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
# Initialize OCR and parser selector
|
|
|
|
|
|
|
| 6 |
ocr_processor = GoogleVisionOCR()
|
| 7 |
parser_selector = ParserSelector()
|
| 8 |
|
| 9 |
router = APIRouter(prefix="/receipts", tags=["Receipt Processing"])
|
| 10 |
|
| 11 |
@router.post("/scan")
|
| 12 |
-
async def process_receipt(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
try:
|
|
|
|
|
|
|
|
|
|
| 14 |
content = await file.read()
|
|
|
|
|
|
|
|
|
|
| 15 |
extracted_text = ocr_processor.extract_text(content)
|
|
|
|
| 16 |
|
| 17 |
if not extracted_text:
|
| 18 |
raise HTTPException(400, "No text extracted from image")
|
| 19 |
|
|
|
|
| 20 |
parser = parser_selector.get_store_parser(extracted_text)
|
|
|
|
|
|
|
| 21 |
parsed_receipt = parser.parse(extracted_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
return {
|
| 24 |
"status": "success",
|
| 25 |
-
"
|
| 26 |
}
|
|
|
|
|
|
|
|
|
|
| 27 |
except Exception as e:
|
|
|
|
| 28 |
raise HTTPException(500, f"Receipt processing error: {str(e)}")
|
|
|
|
| 1 |
+
from fastapi import APIRouter, File, UploadFile, HTTPException, Request, Depends
|
| 2 |
from receipt_processor.google_ocr import GoogleVisionOCR
|
| 3 |
from receipt_processor.parsers.parser_selector import ParserSelector
|
| 4 |
+
from db.receipt_repository import ReceiptRepository
|
| 5 |
+
from utils.rate_limiter import RateLimiter
|
| 6 |
+
import re
|
| 7 |
|
| 8 |
# Initialize OCR and parser selector
|
| 9 |
+
rate_limiter = RateLimiter(max_requests=5)
|
| 10 |
+
receipt_repository = ReceiptRepository()
|
| 11 |
ocr_processor = GoogleVisionOCR()
|
| 12 |
parser_selector = ParserSelector()
|
| 13 |
|
| 14 |
router = APIRouter(prefix="/receipts", tags=["Receipt Processing"])
|
| 15 |
|
| 16 |
@router.post("/scan")
|
| 17 |
+
async def process_receipt(
|
| 18 |
+
request: Request,
|
| 19 |
+
file: UploadFile = File(...),
|
| 20 |
+
client_ip: str = Depends(rate_limiter.check_rate_limit)
|
| 21 |
+
):
|
| 22 |
try:
|
| 23 |
+
print(f"Received file: {file.filename} ({file.content_type})")
|
| 24 |
+
|
| 25 |
+
# Read the file content
|
| 26 |
content = await file.read()
|
| 27 |
+
print(f"File size: {len(content)} bytes")
|
| 28 |
+
|
| 29 |
+
# Extract text using Google OCR
|
| 30 |
extracted_text = ocr_processor.extract_text(content)
|
| 31 |
+
print(f"Extracted text length: {len(extracted_text)} chars")
|
| 32 |
|
| 33 |
if not extracted_text:
|
| 34 |
raise HTTPException(400, "No text extracted from image")
|
| 35 |
|
| 36 |
+
# Select and use appropriate parser
|
| 37 |
parser = parser_selector.get_store_parser(extracted_text)
|
| 38 |
+
print(f"Using parser: {parser.__class__.__name__}")
|
| 39 |
+
|
| 40 |
parsed_receipt = parser.parse(extracted_text)
|
| 41 |
+
print("Parsing completed successfully")
|
| 42 |
+
|
| 43 |
+
# Try to extract date from the receipt
|
| 44 |
+
receipt_date = None
|
| 45 |
+
date_match = re.search(r'(\d{2}[./]\d{2}[./]\d{2,4})', extracted_text)
|
| 46 |
+
if date_match:
|
| 47 |
+
receipt_date = date_match.group(1)
|
| 48 |
+
|
| 49 |
+
# Store the receipt in Supabase
|
| 50 |
+
receipt_repository.create_receipt_request(
|
| 51 |
+
receipt_image=content,
|
| 52 |
+
parsed_data=parsed_receipt,
|
| 53 |
+
receipt_date=receipt_date,
|
| 54 |
+
submission_ip=client_ip
|
| 55 |
+
)
|
| 56 |
|
| 57 |
return {
|
| 58 |
"status": "success",
|
| 59 |
+
"message": "Receipt submitted successfully and pending review."
|
| 60 |
}
|
| 61 |
+
|
| 62 |
+
except HTTPException as e:
|
| 63 |
+
raise e
|
| 64 |
except Exception as e:
|
| 65 |
+
print(f"ERROR: {str(e)}")
|
| 66 |
raise HTTPException(500, f"Receipt processing error: {str(e)}")
|
config/settings.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# Model config
|
| 2 |
MODEL_ONNX_PATH = "product_detector/models/model.onnx"
|
| 3 |
CLASS_NAMES = [
|
|
@@ -33,5 +38,12 @@ INPUT_SIZE = 640
|
|
| 33 |
API_HOST = "0.0.0.0"
|
| 34 |
API_PORT = 7860
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
# Google OCR config
|
| 37 |
GOOGLE_VISION_KEY_PATH = "receipt-vision-key.json"
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
|
| 4 |
+
load_dotenv()
|
| 5 |
+
|
| 6 |
# Model config
|
| 7 |
MODEL_ONNX_PATH = "product_detector/models/model.onnx"
|
| 8 |
CLASS_NAMES = [
|
|
|
|
| 38 |
API_HOST = "0.0.0.0"
|
| 39 |
API_PORT = 7860
|
| 40 |
|
| 41 |
+
# Rate limiting
|
| 42 |
+
MAX_RECEIPTS_PER_HOUR = 5
|
| 43 |
+
|
| 44 |
+
# Supabase
|
| 45 |
+
SUPABASE_URL = os.getenv("SUPABASE_URL")
|
| 46 |
+
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
|
| 47 |
+
|
| 48 |
# Google OCR config
|
| 49 |
GOOGLE_VISION_KEY_PATH = "receipt-vision-key.json"
|
db/receipt_repository.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
+
import json
|
| 3 |
+
import base64
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from typing import Optional, Dict, Any
|
| 6 |
+
from .supabase_client import SupabaseClient
|
| 7 |
+
|
| 8 |
+
class ReceiptRepository:
|
| 9 |
+
def __init__(self):
|
| 10 |
+
self.supabase = SupabaseClient().get_client()
|
| 11 |
+
|
| 12 |
+
def create_receipt_request(
|
| 13 |
+
self,
|
| 14 |
+
receipt_image: bytes,
|
| 15 |
+
parsed_data: Dict[str, Any],
|
| 16 |
+
receipt_date: Optional[str] = None,
|
| 17 |
+
submission_ip: Optional[str] = None
|
| 18 |
+
) -> Dict[str, Any]:
|
| 19 |
+
"""
|
| 20 |
+
Store a new receipt request in the database
|
| 21 |
+
"""
|
| 22 |
+
base64_image = base64.b64encode(receipt_image).decode('utf-8')
|
| 23 |
+
|
| 24 |
+
receipt_data = {
|
| 25 |
+
"receipt_image": base64_image,
|
| 26 |
+
"parsed_data": json.dumps(parsed_data),
|
| 27 |
+
"receipt_date": receipt_date,
|
| 28 |
+
"request_status": "pending",
|
| 29 |
+
"submission_ip": submission_ip
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
result = self.supabase.table("receipt_requests").insert(receipt_data).execute()
|
| 33 |
+
|
| 34 |
+
if len(result.data) > 0:
|
| 35 |
+
return result.data[0]
|
| 36 |
+
else:
|
| 37 |
+
raise Exception("Failed to create receipt request")
|
| 38 |
+
|
| 39 |
+
def get_submission_count(self, ip_address: str) -> int:
|
| 40 |
+
"""
|
| 41 |
+
Get the number of submissions from a specific IP in the last hour
|
| 42 |
+
"""
|
| 43 |
+
# Get current rate limit record
|
| 44 |
+
result = self.supabase.table("submission_rate_limits") \
|
| 45 |
+
.select("*") \
|
| 46 |
+
.eq("ip_address", ip_address) \
|
| 47 |
+
.execute()
|
| 48 |
+
|
| 49 |
+
# Check if the IP has a record
|
| 50 |
+
if len(result.data) == 0:
|
| 51 |
+
# Create new record
|
| 52 |
+
self.supabase.table("submission_rate_limits").insert({
|
| 53 |
+
"ip_address": ip_address,
|
| 54 |
+
"submission_count": 1,
|
| 55 |
+
"window_start_time": datetime.now().isoformat()
|
| 56 |
+
}).execute()
|
| 57 |
+
return 1
|
| 58 |
+
|
| 59 |
+
# Get the existing record
|
| 60 |
+
record = result.data[0]
|
| 61 |
+
window_start = datetime.fromisoformat(record["window_start_time"])
|
| 62 |
+
current_time = datetime.now()
|
| 63 |
+
|
| 64 |
+
# If the window is older than 1 hour, reset it
|
| 65 |
+
if (current_time - window_start).total_seconds() > 3600:
|
| 66 |
+
self.supabase.table("submission_rate_limits") \
|
| 67 |
+
.update({
|
| 68 |
+
"submission_count": 1,
|
| 69 |
+
"window_start_time": current_time.isoformat()
|
| 70 |
+
}) \
|
| 71 |
+
.eq("ip_address", ip_address) \
|
| 72 |
+
.execute()
|
| 73 |
+
return 1
|
| 74 |
+
|
| 75 |
+
# Increment the counter
|
| 76 |
+
new_count = record["submission_count"] + 1
|
| 77 |
+
self.supabase.table("submission_rate_limits") \
|
| 78 |
+
.update({"submission_count": new_count}) \
|
| 79 |
+
.eq("ip_address", ip_address) \
|
| 80 |
+
.execute()
|
| 81 |
+
|
| 82 |
+
return new_count
|
db/supabase_client.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from supabase import create_client, Client
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
|
| 5 |
+
# Load environment variables
|
| 6 |
+
load_dotenv()
|
| 7 |
+
|
| 8 |
+
class SupabaseClient:
|
| 9 |
+
_instance = None
|
| 10 |
+
|
| 11 |
+
def __new__(cls):
|
| 12 |
+
if cls._instance is None:
|
| 13 |
+
cls._instance = super(SupabaseClient, cls).__new__(cls)
|
| 14 |
+
# Initialize the Supabase client
|
| 15 |
+
url = os.getenv("SUPABASE_URL")
|
| 16 |
+
key = os.getenv("SUPABASE_KEY")
|
| 17 |
+
cls._instance.client = create_client(url, key)
|
| 18 |
+
return cls._instance
|
| 19 |
+
|
| 20 |
+
def get_client(self) -> Client:
|
| 21 |
+
return self.client
|
requirements.txt
CHANGED
|
@@ -7,4 +7,6 @@ Pillow
|
|
| 7 |
torch
|
| 8 |
ultralytics
|
| 9 |
python-multipart
|
| 10 |
-
google-cloud-vision
|
|
|
|
|
|
|
|
|
| 7 |
torch
|
| 8 |
ultralytics
|
| 9 |
python-multipart
|
| 10 |
+
google-cloud-vision
|
| 11 |
+
python-dotenv
|
| 12 |
+
supabase
|
utils/rate_limiter.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import Request, HTTPException
|
| 2 |
+
from db.receipt_repository import ReceiptRepository
|
| 3 |
+
|
| 4 |
+
class RateLimiter:
|
| 5 |
+
def __init__(self, max_requests: int = 5):
|
| 6 |
+
self.max_requests = max_requests
|
| 7 |
+
self.repository = ReceiptRepository()
|
| 8 |
+
|
| 9 |
+
async def check_rate_limit(self, request: Request):
|
| 10 |
+
"""
|
| 11 |
+
Check if the client has exceeded the rate limit
|
| 12 |
+
"""
|
| 13 |
+
client_ip = request.client.host
|
| 14 |
+
|
| 15 |
+
# Get the current submission count for this IP
|
| 16 |
+
submission_count = self.repository.get_submission_count(client_ip)
|
| 17 |
+
|
| 18 |
+
# If they've exceeded the limit, raise an exception
|
| 19 |
+
if submission_count > self.max_requests:
|
| 20 |
+
raise HTTPException(
|
| 21 |
+
status_code=429,
|
| 22 |
+
detail=f"Rate limit exceeded. Maximum {self.max_requests} receipts per hour."
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
return client_ip
|