Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import torch | |
| import os | |
| import io | |
| import base64 | |
| import json | |
| import re | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModelForImageTextToText | |
# ── CPU optimization ──────────────────────────────────────────────────────────
# Use every core the host reports; fall back to 2 when the count is unknown.
NUM_CORES = os.cpu_count() or 2

# Size both of PyTorch's thread pools (intra-op and inter-op) to the core count.
torch.set_num_threads(NUM_CORES)
torch.set_num_interop_threads(NUM_CORES)

# Publish the thread count for OpenMP/MKL and point the Hugging Face caches
# at /tmp.
# NOTE(review): torch and transformers are imported before this point, so
# these variables may be read too late to influence them — confirm.
os.environ.update(
    {
        "OMP_NUM_THREADS": str(NUM_CORES),
        "MKL_NUM_THREADS": str(NUM_CORES),
        "TRANSFORMERS_CACHE": "/tmp/hf_cache",
        "HF_HOME": "/tmp/hf_cache",
    }
)
print(f"CPU cores: {NUM_CORES}")
# ASGI application object; the route handlers in this module attach to it.
app = FastAPI(
    title="Receipt & Invoice Reader API",
    version="1.0.0",
)

# CORS is left fully open: any origin, any method, any request header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# ── Load Qwen2.5-VL-3B ────────────────────────────────────────────────────────
MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"

# The Hugging Face libraries resolve their default cache directory when they
# are imported, so a cache variable assigned to os.environ afterwards is not
# picked up on its own.  Hand the configured directory to from_pretrained()
# explicitly; None (variable unset) keeps the library default.
_HF_CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE")

print(f"Loading {MODEL_NAME}...")
# The processor bundles the tokenizer and the image preprocessing for the model.
processor = AutoProcessor.from_pretrained(MODEL_NAME, cache_dir=_HF_CACHE_DIR)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    cache_dir=_HF_CACHE_DIR,
)
# Switch to evaluation mode; this service only runs inference.
model.eval()
print("Model ready!")
# ── Helpers ───────────────────────────────────────────────────────────────────
# Instruction text sent to the model together with each image (used by
# process_image).  It asks for exactly one JSON object with the fields listed
# below, with null for anything not present on the receipt.  The text is
# runtime input to the model: edit it deliberately, not cosmetically.
EXTRACT_PROMPT = """Look at this receipt or invoice image carefully.
Extract ALL information and return ONLY a JSON object with these fields:
{
"vendor": "store or restaurant name",
"date": "date of purchase (YYYY-MM-DD format if possible)",
"time": "time if visible",
"total": 0.00,
"subtotal": 0.00,
"tax": 0.00,
"discount": 0.00,
"currency": "currency symbol or code",
"payment_method": "cash/card/upi etc",
"items": [
{"name": "item name", "quantity": 1, "price": 0.00}
],
"receipt_number": "receipt or invoice number if visible",
"address": "vendor address if visible"
}
Return ONLY the JSON, no explanation. Use null for fields not found."""
def image_to_base64(image: Image.Image) -> str:
    """Encode *image* as a JPEG (quality 85) and return it as base64 text.

    Pillow's JPEG writer rejects modes that carry an alpha channel or a
    palette (``RGBA``, ``LA``, ``P``, ...) with ``OSError``.  Such images are
    converted to ``RGB`` first; modes the writer already accepts are saved
    unchanged.

    Args:
        image: Pillow image to encode.  The argument itself is not modified.

    Returns:
        The base64 encoding of the JPEG bytes, without any data-URL prefix.
    """
    # Modes the JPEG encoder can write directly.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image.mode not in jpeg_modes:
        image = image.convert("RGB")

    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=85)
    return base64.b64encode(buffer.getvalue()).decode()
def _prepare_image(image: Image.Image) -> Image.Image:
    """Return *image* limited to 1024 px on its longest side and in RGB mode."""
    # Resize if too large — saves memory and speeds up inference
    max_size = 1024
    if max(image.size) > max_size:
        # thumbnail() resizes in place; work on a copy so the caller's image
        # object is left untouched.
        image = image.copy()
        image.thumbnail((max_size, max_size), Image.LANCZOS)
    # Convert to RGB
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image


def _parse_json_response(response: str) -> dict:
    """Extract a JSON object from the model's reply, or describe the failure."""
    # First candidate: the whole reply.  Second: the outermost {...} span,
    # which covers replies wrapped in markdown fences or surrounding prose.
    candidates = [response]
    match = re.search(r'\{[\s\S]*\}', response)
    if match:
        candidates.append(match.group(0))

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except (ValueError, RecursionError):
            continue
        # Only a JSON object satisfies the dict contract; a bare list, number
        # or null falls through to the next candidate.
        if isinstance(parsed, dict):
            return parsed

    error = "Could not parse JSON" if match else "No JSON found in response"
    return {"raw_response": response, "parse_error": error}


def process_image(image: Image.Image) -> dict:
    """Run the vision-language model on a receipt image and parse its reply.

    The image is downscaled to at most 1024 px per side and converted to RGB,
    sent to the model together with ``EXTRACT_PROMPT``, and the generated text
    is parsed as JSON.

    Args:
        image: Pillow image of the receipt or invoice.  It is not modified.

    Returns:
        The parsed JSON object on success.  When the reply contains no
        parseable JSON object, a dict with ``raw_response`` (the model text)
        and ``parse_error`` (a short reason) instead.
    """
    image = _prepare_image(image)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": EXTRACT_PROMPT},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text],
        images=[image],
        return_tensors="pt",
        padding=True,
    )

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=600,
            do_sample=False,
            # NOTE(review): sampling is off, so temperature should not affect
            # the output; presumably kept to override a checkpoint default —
            # confirm before removing.
            temperature=1.0,
            repetition_penalty=1.05,
            pad_token_id=processor.tokenizer.eos_token_id,
        )

    # generate() returns prompt + completion; keep only the newly generated part.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = processor.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    return _parse_json_response(response)
# ── Routes ────────────────────────────────────────────────────────────────────
# Registered on the app so the handler is actually reachable; "/" is the
# conventional path for a service-description endpoint.
@app.get("/")
def root():
    """Describe the service: name, status, model in use, and its endpoints."""
    return {
        "service": "Receipt & Invoice Reader API",
        "status": "running",
        "model": MODEL_NAME,
        "endpoints": {
            "POST /extract/upload": "Upload image file → structured JSON",
            "POST /extract/base64": "Send base64 image → structured JSON",
            "GET /health": "Health check",
        },
    }
# Path matches the "GET /health" entry advertised by the service description.
@app.get("/health")
def health():
    """Report liveness together with the model name and configured core count."""
    return {"status": "ok", "model": MODEL_NAME, "cpu_cores": NUM_CORES}
# Path matches the "POST /extract/upload" entry advertised by the service
# description.
@app.post("/extract/upload")
async def extract_upload(file: UploadFile = File(...)):
    """Extract structured receipt data from an uploaded image file.

    Args:
        file: Multipart upload; must be declared as JPEG, PNG or WEBP and be
            at most 10 MB.

    Returns:
        ``{"success": True, "data": <result of process_image>}``.

    Raises:
        HTTPException: 400 for an unsupported content type, an oversized
            file, or bytes that cannot be opened as an image; 500 when
            processing the image fails.
    """
    # Validate file type
    allowed = ("image/jpeg", "image/jpg", "image/png", "image/webp")
    if file.content_type not in allowed:
        raise HTTPException(status_code=400, detail="Only JPEG, PNG, WEBP images accepted")

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without copying all of it into memory first.
    max_bytes = 10 * 1024 * 1024
    contents = await file.read(max_bytes + 1)
    if len(contents) > max_bytes:
        raise HTTPException(status_code=400, detail="File too large (max 10MB)")

    # Bytes that are not a readable image are the client's error, not ours.
    try:
        image = Image.open(io.BytesIO(contents))
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

    # NOTE(review): process_image is a synchronous call made directly inside
    # this coroutine, so the event loop is occupied until it returns.
    try:
        result = process_image(image)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"success": True, "data": result}
class Base64Input(BaseModel):
    """JSON request body carrying a single image as base64 text."""

    # Base64 text of the image.  extract_base64 drops everything up to the
    # first comma when one is present, so a data-URL prefix is accepted too.
    image: str  # base64 encoded image
# Path matches the "POST /extract/base64" entry advertised by the service
# description.
@app.post("/extract/base64")
async def extract_base64(body: Base64Input):
    """Extract structured receipt data from a base64-encoded image.

    Args:
        body: Request body whose ``image`` field holds the base64 text,
            optionally preceded by a data-URL prefix ending in a comma.

    Returns:
        ``{"success": True, "data": <result of process_image>}``.

    Raises:
        HTTPException: 400 when the field is blank, is not valid base64, or
            does not decode to a readable image; 500 when processing the
            image fails.
    """
    if not body.image.strip():
        raise HTTPException(status_code=400, detail="Image cannot be empty")

    # Strip data URL prefix if present
    img_data = body.image
    if "," in img_data:
        img_data = img_data.split(",")[1]

    # Malformed base64 (binascii.Error is a ValueError) and bytes that are not
    # a readable image are client errors, so report 400 rather than 500.
    # NOTE(review): unlike the upload endpoint, no size limit is applied here.
    try:
        image_bytes = base64.b64decode(img_data)
        image = Image.open(io.BytesIO(image_bytes))
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {e}") from e

    # NOTE(review): process_image is a synchronous call made directly inside
    # this coroutine, so the event loop is occupied until it returns.
    try:
        result = process_image(image)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"success": True, "data": result}