# app.py — Receipt & Invoice Reader API (Qwen2.5-VL-3B-Instruct, CPU inference)
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import os
import io
import base64
import json
import re
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText
# ── CPU optimization ──────────────────────────────────────────────────────────
# Use every available core for intra- and inter-op parallelism on CPU.
NUM_CORES = os.cpu_count() or 2  # cpu_count() can return None; fall back to 2
torch.set_num_threads(NUM_CORES)
torch.set_num_interop_threads(NUM_CORES)
# NOTE(review): these env vars are assigned *after* `import torch` and
# `import transformers` above; OMP/MKL and the HF cache location are typically
# read at import time, so these may have no effect — consider setting them
# before the imports. TODO confirm.
os.environ["OMP_NUM_THREADS"] = str(NUM_CORES)
os.environ["MKL_NUM_THREADS"] = str(NUM_CORES)
# Point the Hugging Face cache at a writable location (e.g. on HF Spaces
# only /tmp is guaranteed writable).
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HOME"] = "/tmp/hf_cache"
print(f"CPU cores: {NUM_CORES}")
# FastAPI application with fully permissive CORS so browser front-ends
# hosted anywhere can call the API directly.
app = FastAPI(title="Receipt & Invoice Reader API", version="1.0.0")
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# ── Load Qwen2.5-VL-3B ────────────────────────────────────────────────────────
MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
print(f"Loading {MODEL_NAME}...")
# Processor bundles the tokenizer and the image preprocessor for this
# vision-language model; both are needed to build model inputs.
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # full precision: CPU inference gains nothing from fp16
)
model.eval()  # inference mode (disables dropout etc.)
print("Model ready!")
# ── Helpers ───────────────────────────────────────────────────────────────────
# Instruction prompt sent alongside every image. It asks the model to return a
# bare JSON object matching the schema below, which process_image() then parses.
# The literal text (including the embedded JSON template) is part of the model
# contract — do not reformat it.
EXTRACT_PROMPT = """Look at this receipt or invoice image carefully.
Extract ALL information and return ONLY a JSON object with these fields:
{
"vendor": "store or restaurant name",
"date": "date of purchase (YYYY-MM-DD format if possible)",
"time": "time if visible",
"total": 0.00,
"subtotal": 0.00,
"tax": 0.00,
"discount": 0.00,
"currency": "currency symbol or code",
"payment_method": "cash/card/upi etc",
"items": [
{"name": "item name", "quantity": 1, "price": 0.00}
],
"receipt_number": "receipt or invoice number if visible",
"address": "vendor address if visible"
}
Return ONLY the JSON, no explanation. Use null for fields not found."""
def image_to_base64(image: Image.Image) -> str:
    """Encode *image* as the base64 string of its JPEG (quality 85) bytes."""
    with io.BytesIO() as sink:
        image.save(sink, format="JPEG", quality=85)
        jpeg_bytes = sink.getvalue()
    return base64.b64encode(jpeg_bytes).decode()
def _parse_json_response(response: str) -> dict:
    """Best-effort parse of the model's reply into a dict.

    Tries a direct ``json.loads`` first; if that fails, grabs the first
    ``{...}`` span (the model often wraps the JSON in markdown fences or
    surrounding prose) and retries. On total failure returns the raw text
    plus a ``parse_error`` marker instead of raising, so callers always
    get a dict back.
    """
    try:
        return json.loads(response)
    except Exception:
        pass
    match = re.search(r'\{[\s\S]*\}', response)
    if match is None:
        return {"raw_response": response, "parse_error": "No JSON found in response"}
    try:
        return json.loads(match.group(0))
    except Exception:
        return {"raw_response": response, "parse_error": "Could not parse JSON"}


def process_image(image: Image.Image) -> dict:
    """Run the VL model over a receipt/invoice image and return extracted fields.

    Args:
        image: PIL image of the receipt; resized/converted in place as needed.

    Returns:
        A dict with the fields described in EXTRACT_PROMPT, or a
        ``{"raw_response": ..., "parse_error": ...}`` dict when the model's
        output cannot be parsed as JSON.
    """
    # Resize if too large — saves memory and speeds up CPU inference.
    max_size = 1024
    if max(image.size) > max_size:
        image.thumbnail((max_size, max_size), Image.LANCZOS)
    # Model preprocessing expects 3-channel RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": EXTRACT_PROMPT}
            ]
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text],
        images=[image],
        return_tensors="pt",
        padding=True
    )
    with torch.no_grad():
        # Greedy decoding (do_sample=False) for deterministic extraction.
        # Note: a temperature kwarg is deliberately NOT passed — it is ignored
        # under greedy decoding and only triggers a transformers warning.
        outputs = model.generate(
            **inputs,
            max_new_tokens=600,
            do_sample=False,
            repetition_penalty=1.05,
            pad_token_id=processor.tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens (skip the echoed prompt).
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = processor.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return _parse_json_response(response)
# ── Routes ────────────────────────────────────────────────────────────────────
@app.get("/")
def root():
    """Service landing endpoint: name, status, model, and a directory of routes.

    Fixes mojibake in the endpoint descriptions: the original file contained
    "β†’" (UTF-8 "→" decoded with the wrong codec) in user-visible strings.
    """
    return {
        "service": "Receipt & Invoice Reader API",
        "status": "running",
        "model": MODEL_NAME,
        "endpoints": {
            "POST /extract/upload": "Upload image file → structured JSON",
            "POST /extract/base64": "Send base64 image → structured JSON",
            "GET /health": "Health check"
        }
    }
@app.get("/health")
def health():
    """Lightweight liveness probe reporting the loaded model and core count."""
    payload = {"status": "ok", "model": MODEL_NAME, "cpu_cores": NUM_CORES}
    return payload
@app.post("/extract/upload")
async def extract_upload(file: UploadFile = File(...)):
    """Accept an uploaded receipt/invoice image and return extracted fields.

    Rejects non-image uploads (400), files over 10 MB (400); any failure
    while decoding or running the model surfaces as a 500.
    """
    # Guard: only common raster image types are accepted.
    accepted_types = ("image/jpeg", "image/jpg", "image/png", "image/webp")
    if file.content_type not in accepted_types:
        raise HTTPException(status_code=400, detail="Only JPEG, PNG, WEBP images accepted")
    contents = await file.read()
    max_bytes = 10 * 1024 * 1024
    if len(contents) > max_bytes:
        raise HTTPException(status_code=400, detail="File too large (max 10MB)")
    try:
        image = Image.open(io.BytesIO(contents))
        extracted = process_image(image)
        return {"success": True, "data": extracted}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Request body schema for POST /extract/base64.
class Base64Input(BaseModel):
    # base64-encoded image; a "data:image/...;base64," prefix is tolerated
    # (stripped by the endpoint before decoding)
    image: str  # base64 encoded image
@app.post("/extract/base64")
async def extract_base64(body: Base64Input):
    """Decode a base64 (or data-URL) image and return extracted receipt fields.

    Raises:
        HTTPException 400: empty payload, undecodable base64, or bytes that
            are not a readable image — these are client errors, not server
            errors (previously they surfaced as 500).
        HTTPException 500: unexpected failure during model inference.
    """
    if not body.image.strip():
        raise HTTPException(status_code=400, detail="Image cannot be empty")
    img_data = body.image
    # Strip a data-URL prefix ("data:image/png;base64,....") if present.
    # maxsplit=1 keeps the payload intact even if it ever contained a comma.
    if "," in img_data:
        img_data = img_data.split(",", 1)[1]
    try:
        # binascii.Error (bad base64) subclasses ValueError; PIL raises
        # UnidentifiedImageError (an OSError) for unreadable bytes. Both are
        # the client's fault → 400.
        image_bytes = base64.b64decode(img_data)
        image = Image.open(io.BytesIO(image_bytes))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {e}")
    try:
        result = process_image(image)
        return {"success": True, "data": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))