Spaces:
Running
Running
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import List | |
| import requests | |
| import base64 | |
| import json | |
| import re | |
| import io | |
| import tempfile | |
| import os | |
# Application object. The description is shown verbatim in the generated
# OpenAPI / Swagger UI, so it must contain a real arrow, not mis-decoded bytes.
app = FastAPI(
    title="Invoice OCR API",
    description=(
        "Two-step pipeline: nemoretriever-ocr-v1 → nvidia-nemotron-nano-9b-v2 "
        "for Tax Invoice extraction. Supports images AND multi-page PDFs."
    ),
)

# NOTE(review): a wildcard origin together with allow_credentials=True lets any
# website issue credentialed cross-origin requests. This service does not appear
# to use cookies or session auth, so the setting is probably unnecessary --
# confirm with the front-end owners and either list explicit origins or drop
# allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Configuration ─────────────────────────────────────────────────────────────
# The NVIDIA API key is a secret and must never live in source control. It is
# read from the NVIDIA_API_KEY environment variable (on hosted platforms,
# configure it as a secret). When the variable is unset the key is empty and
# the upstream NVIDIA endpoints will reject requests with an auth error.
# SECURITY: the key that was previously committed here is compromised and
# should be revoked and rotated.
NVIDIA_API_KEY = os.environ.get("NVIDIA_API_KEY", "")

# Upstream endpoints: OCR model first, then the chat-completions LLM.
OCR_URL = "https://ai.api.nvidia.com/v1/cv/nvidia/nemoretriever-ocr-v1"
LLM_URL = "https://integrate.api.nvidia.com/v1/chat/completions"
LLM_MODEL = "nvidia/nvidia-nemotron-nano-9b-v2"

# Pre-built header dicts shared by every outgoing request.
OCR_HEADERS = {
    "Authorization": f"Bearer {NVIDIA_API_KEY}",
    "Accept": "application/json",
}
LLM_HEADERS = {
    "Authorization": f"Bearer {NVIDIA_API_KEY}",
    "Content-Type": "application/json",
}
# ── System prompt ─────────────────────────────────────────────────────────────
# Sent as the "system" message on every LLM call. It pins the exact JSON schema
# the model must return and lists extraction hints for Indian GST invoices.
# The grand_total rule names the rupee sign (₹); the mis-decoded bytes that
# previously stood in its place told the model to look for a nonsense token.
INVOICE_SYSTEM_PROMPT = """You are a Tax Invoice data extraction assistant for Indian GST invoices.
You will receive OCR text from a tax invoice image. Return ONLY a valid JSON object. No markdown, no explanation.
JSON schema (return exactly this):
{
"invoice_number": "invoice number e.g. ACMPL/01/19-20 (string)",
"eway_bill_number": "e-Way Bill No if present (string)",
"invoice_date": "date e.g. 18-Apr-2019 (string)",
"mode_of_payment": "cash/credit/UPI/bank transfer etc (string)",
"supplier_ref": "supplier reference number (string)",
"buyer_order_number": "buyer's PO number (string)",
"dispatch_document_number": "dispatch doc number (string)",
"dispatched_through": "courier or transport name (string)",
"destination": "delivery destination (string)",
"delivery_note": "delivery note number (string)",
"vendor_name": "name of selling company e.g. Ace Mobile Manufacturer Pvt Ltd (string)",
"vendor_address": "full address of vendor (string)",
"vendor_gstin": "15-char GSTIN of vendor e.g. 09AABCS1429B1ZS (string)",
"vendor_state": "state name and code e.g. Uttar Pradesh, Code: 09 (string)",
"vendor_email": "email of vendor (string)",
"buyer_name": "name of buyer/customer e.g. The Mobile Planet (string)",
"buyer_address": "full address of buyer (string)",
"buyer_gstin": "15-char GSTIN of buyer e.g. 09AAGCA1654H1ZQ (string)",
"buyer_state": "state name and code of buyer (string)",
"items": [
{
"sl_no": "serial number (string)",
"description": "description of goods (string)",
"batch": "batch number if present (string)",
"hsn_sac": "HSN or SAC code (string)",
"quantity": "quantity with unit e.g. 500 Nos (string)",
"rate": "rate per unit e.g. 6000.00 (string)",
"per": "unit type e.g. Nos (string)",
"amount": "line total e.g. 30,00,000.00 (string)"
}
],
"taxable_value": "total taxable amount before tax (string)",
"cgst_rate": "CGST rate percentage e.g. 6% (string)",
"cgst_amount": "CGST amount (string)",
"sgst_rate": "SGST rate percentage e.g. 6% (string)",
"sgst_amount": "SGST amount (string)",
"igst_rate": "IGST rate if applicable (string)",
"igst_amount": "IGST amount if applicable (string)",
"output_cgst": "Output CGST amount (string)",
"output_sgst": "Output SGST amount (string)",
"total_tax_amount": "total tax amount (string)",
"grand_total": "final invoice total e.g. 96,32,000.00 (string)",
"amount_in_words": "amount in words e.g. INR Ninety Six Lakh Thirty Two Thousand Only (string)",
"tax_amount_in_words": "tax amount in words (string)",
"hsn_summary": [
{
"hsn_sac": "HSN code (string)",
"taxable_value": "taxable value for this HSN (string)",
"cgst_rate": "CGST rate (string)",
"cgst_amount": "CGST amount (string)",
"sgst_rate": "SGST rate (string)",
"sgst_amount": "SGST amount (string)",
"total_tax": "total tax for this HSN (string)"
}
],
"declaration": "declaration text at bottom (string)",
"authorised_signatory": "authorised signatory label (string)",
"is_computer_generated": true
}
CRITICAL RULES:
- invoice_number: look for Invoice No., Bill No., Ref No. near the top right area
- vendor_name: the company at the TOP of the invoice, usually with logo
- buyer_name: look for 'Buyer', 'Bill To', 'Sold To' section
- GSTIN: exactly 15 characters, mix of letters and digits e.g. 09AABCS1429B1ZS
- items: extract EVERY line item row in the goods table including batch info
- amounts: keep exact format with commas e.g. 30,00,000.00
- hsn_summary: extract the tax summary table at the bottom (HSN/SAC wise breakdown)
- output_cgst / output_sgst: look for 'Output CGST' and 'Output SGST' labels in totals
- grand_total: the final TOTAL amount, look for ₹ symbol
- amount_in_words: the spelled-out amount e.g. 'INR Ninety Six Lakh...'
- If a field is not found, use "" for strings, [] for arrays, false for booleans"""
# ── PDF → images ──────────────────────────────────────────────────────────────
def _pil_image_to_png_bytes(image) -> bytes:
    """Encode one PIL image as PNG and return the encoded bytes."""
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    return buf.getvalue()


def pdf_bytes_to_png_list(pdf_bytes: bytes, dpi: int = 200) -> list[bytes]:
    """
    Convert every page of a PDF into a PNG byte string.

    Tries pdf2image first (a wrapper around poppler's command-line tools) and
    falls back to pypdfium2 when pdf2image is not importable *or* when
    pdf2image is installed but the poppler binaries are missing.

    Args:
        pdf_bytes: Raw bytes of the PDF document.
        dpi: Rendering resolution in dots per inch.

    Returns:
        One PNG-encoded byte string per page, in page order.

    Raises:
        HTTPException: 500 when neither rendering backend can be imported.
    """
    try:
        from pdf2image import convert_from_bytes
        from pdf2image.exceptions import PDFInfoNotInstalledError
    except ImportError:
        convert_from_bytes = None  # pdf2image not installed -> use fallback

    if convert_from_bytes is not None:
        try:
            pil_images = convert_from_bytes(pdf_bytes, dpi=dpi, fmt="png")
        except (ImportError, PDFInfoNotInstalledError):
            # pdf2image is present but poppler (or a lazy dependency) is not;
            # previously this escaped as an unhandled error instead of falling
            # back to pypdfium2.
            pass
        else:
            return [_pil_image_to_png_bytes(img) for img in pil_images]

    try:
        import pypdfium2 as pdfium

        pdf = pdfium.PdfDocument(pdf_bytes)
        try:
            # PDF user space is 72 units per inch, so dpi / 72 is the scale
            # factor that yields the requested resolution.
            scale = dpi / 72
            return [
                _pil_image_to_png_bytes(
                    pdf[page_index].render(scale=scale, rotation=0).to_pil()
                )
                for page_index in range(len(pdf))
            ]
        finally:
            # Release the native document handle deterministically rather than
            # waiting for garbage collection.
            pdf.close()
    except ImportError as exc:
        raise HTTPException(
            status_code=500,
            detail=(
                "PDF conversion requires pdf2image+poppler or pypdfium2. "
                "Install with: pip install pdf2image pypdfium2 "
                "and apt-get install poppler-utils"
            ),
        ) from exc
# ── OCR helpers ───────────────────────────────────────────────────────────────
def extract_all_text_sorted(ocr_json: dict) -> tuple[str, list]:
    """
    Flatten an OCR response into reading-order text.

    Detections are taken from ``data[0]["text_detections"]`` when ``data`` is
    non-empty, otherwise from a top-level ``text_detections`` key. Each
    detection's text comes from ``text_prediction.text`` if that key exists,
    else from ``text``. Entries that are not dicts or whose text is blank are
    dropped.

    Returns:
        ``(full_text, items)`` where ``items`` is a list of
        ``{"text", "y", "x"}`` dicts sorted by horizontal band then by left
        edge, and ``full_text`` is the item texts joined with newlines.
    """
    # Height of one sorting band. Presumably the bounding-box coordinates are
    # normalised to 0..1, making this ~1.2% of page height -- TODO confirm.
    BAND = 0.012

    payload = ocr_json.get("data", [])
    if payload:
        raw_detections = payload[0].get("text_detections", [])
    else:
        raw_detections = ocr_json.get("text_detections", [])

    entries = []
    for detection in raw_detections:
        if not isinstance(detection, dict):
            continue

        holder = detection["text_prediction"] if "text_prediction" in detection else detection
        label = holder.get("text", "").strip()
        if not label:
            continue

        corners = detection.get("bounding_box", {}).get("points", [])
        if corners:
            # Vertical position = mean of corner y's; horizontal = leftmost x.
            centre_y = sum(corner["y"] for corner in corners) / len(corners)
            left_x = min(corner["x"] for corner in corners)
        else:
            centre_y = 0
            left_x = 0
        entries.append({"text": label, "y": centre_y, "x": left_x})

    def reading_order(entry):
        # Same band -> same first key, so ties resolve left-to-right.
        return (round(entry["y"] / BAND), entry["x"])

    ordered = sorted(entries, key=reading_order)
    joined = "\n".join(entry["text"] for entry in ordered)
    return joined, ordered
def run_ocr_on_bytes(image_bytes: bytes, page_label: str = "") -> tuple[str, list]:
    """
    Run the NVIDIA OCR endpoint on raw image bytes.

    Args:
        image_bytes: Encoded image data (sent inline as a base64 data URL).
        page_label: Optional page identifier, used only in log and error text.

    Returns:
        ``(text, detections)`` as produced by ``extract_all_text_sorted``.

    Raises:
        HTTPException: 413 when the base64 payload is 1,000,000 characters or
            more; 502 when the OCR request fails or the response body is not
            valid JSON.
    """
    image_b64 = base64.b64encode(image_bytes).decode()
    if len(image_b64) >= 1_000_000:
        where = f" (page {page_label})" if page_label else ""
        raise HTTPException(
            status_code=413,
            detail=f"Image too large{where}. Resize and retry.",
        )

    # NOTE(review): the data URL always declares image/png, even when the
    # caller uploaded JPEG/WEBP/GIF bytes. Left unchanged because the upstream
    # service's handling of other MIME labels is not visible here -- confirm.
    payload = {"input": [{"type": "image_url", "url": f"data:image/png;base64,{image_b64}"}]}

    try:
        resp = requests.post(OCR_URL, headers=OCR_HEADERS, json=payload, timeout=30)
        resp.raise_for_status()
        # Decoding stays inside the try so a non-JSON body (e.g. an HTML error
        # page) becomes a 502 instead of an unhandled exception. ValueError
        # covers requests versions whose JSON error is not a RequestException.
        ocr_json = resp.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        raise HTTPException(status_code=502, detail=f"NVIDIA OCR error: {str(e)}") from e

    text, items = extract_all_text_sorted(ocr_json)
    label = f"page {page_label} " if page_label else ""
    print(f"OCR {label}({len(text)} chars):\n{text[:400]}\n{'='*60}")
    return text, items
async def run_ocr(file: UploadFile) -> tuple[str, list]:
    """
    Read the uploaded file and OCR it.

    - If PDF → convert each page to PNG, OCR all pages, concatenate the text
      with a ``--- Page N ---`` header before each non-blank page.
    - If image → OCR directly.

    The PDF rendering and the OCR HTTP calls are blocking, so they are run in
    worker threads to keep the event loop free for other requests.

    Returns:
        ``(combined_text, detections_of_first_page)``.

    Raises:
        HTTPException: 422 when a PDF yields no pages; anything raised by
            ``pdf_bytes_to_png_list`` or ``run_ocr_on_bytes`` propagates.
    """
    # Imported here so this function does not depend on the module's import
    # block being edited.
    import asyncio

    content = await file.read()
    content_type = (file.content_type or "").lower()
    filename = (file.filename or "").lower()
    is_pdf = content_type == "application/pdf" or filename.endswith(".pdf")

    if not is_pdf:
        # Plain image upload: a single OCR call.
        return await asyncio.to_thread(run_ocr_on_bytes, content)

    print(f"PDF detected ({len(content)} bytes). Converting pages to images…")
    page_images = await asyncio.to_thread(pdf_bytes_to_png_list, content)
    if not page_images:
        raise HTTPException(status_code=422, detail="PDF has no renderable pages.")

    all_texts: list[str] = []
    first_detections: list = []
    for i, img_bytes in enumerate(page_images, start=1):
        page_text, detections = await asyncio.to_thread(run_ocr_on_bytes, img_bytes, str(i))
        if page_text.strip():
            all_texts.append(f"--- Page {i} ---\n{page_text}")
        if i == 1:
            first_detections = detections

    combined = "\n\n".join(all_texts)
    print(f"Total combined OCR text: {len(combined)} chars across {len(page_images)} page(s)")
    return combined, first_detections
# ── LLM ───────────────────────────────────────────────────────────────────────
def _loads_dict(text: str):
    """Parse *text* as JSON; return the dict, or None if invalid or not a dict."""
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


def _close_truncated_json(text: str) -> str:
    """Append the quote/brackets needed to close JSON that was cut off mid-way.

    Scans the text while tracking string state and a stack of open ``{``/``[``,
    then closes them in reverse (innermost-first) order.
    """
    closers = []
    in_string = False
    escaped = False
    for ch in text:
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "{":
            closers.append("}")
        elif ch == "[":
            closers.append("]")
        elif ch in "}]" and closers and closers[-1] == ch:
            closers.pop()

    patched = text
    if in_string:
        if escaped:
            patched = patched[:-1]  # drop a dangling backslash
        patched += '"'
    else:
        patched = patched.rstrip().rstrip(",")
    return patched + "".join(reversed(closers))


def call_llm(ocr_text: str) -> dict:
    """
    Send OCR text to the LLM and return the extracted invoice as a dict.

    The reply is parsed leniently, in this order: direct JSON parse; the
    outermost ``{...}`` span; finally a repair pass that closes truncated JSON.

    Args:
        ocr_text: Combined OCR text of the invoice.

    Returns:
        The parsed JSON object.

    Raises:
        HTTPException: 502 when the request fails, the reply is empty, or no
            JSON object can be recovered from it.
    """
    payload = {
        "model": LLM_MODEL,
        "max_tokens": 5000,
        "temperature": 0.1,
        "top_p": 0.9,
        "messages": [
            {"role": "system", "content": INVOICE_SYSTEM_PROMPT},
            {
                "role": "user",
                "content": (
                    f"OCR text from tax invoice:\n\n{ocr_text}\n\n"
                    "Return ONLY the complete JSON object."
                ),
            },
        ],
    }
    try:
        resp = requests.post(LLM_URL, headers=LLM_HEADERS, json=payload, timeout=200)
        resp.raise_for_status()
        llm_json = resp.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers requests versions whose JSON decode error is not a
        # RequestException subclass.
        raise HTTPException(status_code=502, detail=f"NVIDIA LLM error: {str(e)}") from e

    # Be defensive about the envelope: an empty "choices" list or a null
    # "content" must yield the 502 below, not an IndexError/TypeError.
    choices = llm_json.get("choices") if isinstance(llm_json, dict) else None
    choice = choices[0] if isinstance(choices, list) and choices else {}
    if not isinstance(choice, dict):
        choice = {}
    message = choice.get("message") or {}
    raw = (message.get("content") if isinstance(message, dict) else "") or ""
    if not isinstance(raw, str):
        raw = str(raw)
    finish = choice.get("finish_reason", "")
    print(f"LLM finish={finish}\nRaw (first 600):\n{raw[:600]}\n{'='*60}")

    if not raw:
        raise HTTPException(status_code=502, detail="LLM returned empty response")

    # Strip markdown code fences the model may add despite the instructions.
    cleaned = re.sub(r"```json\s*", "", raw, flags=re.IGNORECASE)
    cleaned = re.sub(r"```\s*", "", cleaned).strip()

    # 1) Direct parse.
    parsed = _loads_dict(cleaned)
    if parsed is not None:
        return parsed

    # 2) Outermost {...} span (handles leading/trailing chatter).
    match = re.search(r"\{[\s\S]*\}", cleaned)
    if match:
        parsed = _loads_dict(match.group(0))
        if parsed is not None:
            return parsed

    # 3) Repair output truncated by the token limit. Closers must be emitted
    #    innermost-first; appending every "]" before every "}" cannot fix the
    #    usual '{"items": [{...' cut-off.
    start = cleaned.find("{")
    if start != -1:
        parsed = _loads_dict(_close_truncated_json(cleaned[start:]))
        if parsed is not None:
            print("WARNING: used bracket-patching to fix truncated JSON")
            return parsed

    raise HTTPException(
        status_code=502,
        detail=f"JSON parse failed (finish={finish}). Preview: {raw[:400]}"
    )
# ── Pydantic models ───────────────────────────────────────────────────────────
# One row of the invoice's goods table. Every field is a required string; the
# endpoint fills missing values with "" and truncates each to a fixed length.
class LineItem(BaseModel):
    sl_no: str  # serial number of the row
    description: str  # description of goods
    batch: str  # batch number, "" when absent
    hsn_sac: str  # HSN or SAC code
    quantity: str  # quantity with unit, e.g. "500 Nos"
    rate: str  # rate per unit
    per: str  # unit type, e.g. "Nos"
    amount: str  # line total, original formatting kept
# One row of the HSN/SAC-wise tax summary table. All fields are required
# strings, mirroring the "hsn_summary" entries requested in the system prompt.
class HSNSummary(BaseModel):
    hsn_sac: str  # HSN code
    taxable_value: str  # taxable value for this HSN
    cgst_rate: str  # CGST rate
    cgst_amount: str  # CGST amount
    sgst_rate: str  # SGST rate
    sgst_amount: str  # SGST amount
    total_tax: str  # total tax for this HSN
# Full structured result built by the extraction endpoint. Field names match
# the JSON schema given to the LLM in INVOICE_SYSTEM_PROMPT; all text fields
# are required strings.
class InvoiceData(BaseModel):
    # Invoice header / logistics
    invoice_number: str
    eway_bill_number: str
    invoice_date: str
    mode_of_payment: str
    supplier_ref: str
    buyer_order_number: str
    dispatch_document_number: str
    dispatched_through: str
    destination: str
    delivery_note: str
    # Seller
    vendor_name: str
    vendor_address: str
    vendor_gstin: str
    vendor_state: str
    vendor_email: str
    # Buyer
    buyer_name: str
    buyer_address: str
    buyer_gstin: str
    buyer_state: str
    # Goods table rows
    items: List[LineItem]
    # Tax breakdown and totals
    taxable_value: str
    cgst_rate: str
    cgst_amount: str
    sgst_rate: str
    sgst_amount: str
    igst_rate: str
    igst_amount: str
    output_cgst: str
    output_sgst: str
    total_tax_amount: str
    grand_total: str
    amount_in_words: str
    tax_amount_in_words: str
    # HSN/SAC-wise tax summary rows
    hsn_summary: List[HSNSummary]
    # Footer
    declaration: str
    authorised_signatory: str
    is_computer_generated: bool
    # Extra metadata for PDF uploads
    source_pages: int = 1
# ── Endpoint ──────────────────────────────────────────────────────────────────
# Maximum stored length for each top-level text field of InvoiceData.
_INVOICE_TEXT_LIMITS = {
    "invoice_number": 60,
    "eway_bill_number": 30,
    "invoice_date": 30,
    "mode_of_payment": 60,
    "supplier_ref": 60,
    "buyer_order_number": 60,
    "dispatch_document_number": 60,
    "dispatched_through": 100,
    "destination": 100,
    "delivery_note": 60,
    "vendor_name": 150,
    "vendor_address": 300,
    "vendor_gstin": 20,
    "vendor_state": 100,
    "vendor_email": 100,
    "buyer_name": 150,
    "buyer_address": 300,
    "buyer_gstin": 20,
    "buyer_state": 100,
    "taxable_value": 30,
    "cgst_rate": 10,
    "cgst_amount": 30,
    "sgst_rate": 10,
    "sgst_amount": 30,
    "igst_rate": 10,
    "igst_amount": 30,
    "output_cgst": 30,
    "output_sgst": 30,
    "total_tax_amount": 30,
    "grand_total": 30,
    "amount_in_words": 300,
    "tax_amount_in_words": 300,
    "declaration": 500,
    "authorised_signatory": 100,
}

# Maximum stored length for each LineItem field.
_LINE_ITEM_LIMITS = {
    "sl_no": 10,
    "description": 200,
    "batch": 50,
    "hsn_sac": 20,
    "quantity": 30,
    "rate": 30,
    "per": 20,
    "amount": 30,
}

# Maximum stored length for each HSNSummary field.
_HSN_SUMMARY_LIMITS = {
    "hsn_sac": 20,
    "taxable_value": 30,
    "cgst_rate": 10,
    "cgst_amount": 30,
    "sgst_rate": 10,
    "sgst_amount": 30,
    "total_tax": 30,
}


def _as_text(value, limit: int) -> str:
    """Coerce an LLM value to a stripped string of at most *limit* chars; None -> ""."""
    if value is None:
        return ""
    return str(value).strip()[:limit]


def _as_bool(value) -> bool:
    """Coerce an LLM value to bool, reading strings like "false" as False."""
    if isinstance(value, str):
        return value.strip().lower() in {"true", "yes", "y", "1"}
    return bool(value)


def _dict_rows(value) -> list:
    """Return the dict elements of *value* if it is a list, else an empty list."""
    if not isinstance(value, list):
        return []
    return [row for row in value if isinstance(row, dict)]


def _count_source_pages(ocr_text: str) -> int:
    """Return the highest page number found in "--- Page N ---" header lines (min 1)."""
    numbers = re.findall(r"^--- Page (\d+) ---$", ocr_text, flags=re.MULTILINE)
    return max((int(n) for n in numbers), default=1)


# The route path is taken from the endpoint map returned by root(); without
# this decorator the handler is never registered on the app.
@app.post("/extract-invoice", response_model=InvoiceData)
async def extract_invoice(file: UploadFile = File(...)):
    """
    Upload a tax invoice image (JPEG/PNG/WEBP) **or a PDF** →
    structured JSON with all fields.
    For PDFs, every page is rendered to PNG at 200 dpi, OCR'd individually,
    and the combined text is sent to the LLM for a single unified extraction.
    """
    # Function-scope import so this block does not rely on the module's import
    # list being edited.
    import asyncio

    allowed_images = {"image/jpeg", "image/jpg", "image/png", "image/webp", "image/gif"}
    content_type = (file.content_type or "").lower()
    filename = (file.filename or "").lower()
    is_pdf = content_type == "application/pdf" or filename.endswith(".pdf")
    # A missing content type is let through; only explicit unsupported types
    # are rejected.
    if content_type and content_type not in allowed_images and not is_pdf:
        raise HTTPException(status_code=415, detail=f"Unsupported type: {file.content_type}")

    ocr_text, _ = await run_ocr(file)
    if not ocr_text.strip():
        raise HTTPException(status_code=422, detail="OCR produced no text.")

    # Page count from the "--- Page N ---" headers run_ocr inserts. Using the
    # highest N (rather than counting substrings) stays correct when blank
    # pages in the middle were skipped. Plain images have no headers -> 1.
    page_count = _count_source_pages(ocr_text)

    # call_llm performs a blocking HTTP request (up to 200 s); run it in a
    # worker thread so the event loop keeps serving other requests.
    parsed = await asyncio.to_thread(call_llm, ocr_text)

    return InvoiceData(
        **{key: _as_text(parsed.get(key), limit) for key, limit in _INVOICE_TEXT_LIMITS.items()},
        items=[
            LineItem(**{key: _as_text(row.get(key), limit) for key, limit in _LINE_ITEM_LIMITS.items()})
            for row in _dict_rows(parsed.get("items"))
        ],
        hsn_summary=[
            HSNSummary(**{key: _as_text(row.get(key), limit) for key, limit in _HSN_SUMMARY_LIMITS.items()})
            for row in _dict_rows(parsed.get("hsn_summary"))
        ],
        is_computer_generated=_as_bool(parsed.get("is_computer_generated", False)),
        source_pages=page_count,
    )
# Registered at GET /health, the path advertised in root()'s endpoint map;
# without the decorator the handler is never reachable.
@app.get("/health")
async def health():
    """Liveness probe: report a static healthy status and the configured LLM model."""
    return {"status": "healthy", "model": LLM_MODEL}
# NOTE(review): "/" is inferred from the handler's name and role as the service
# index -- confirm the intended path. Without a decorator it is not registered.
@app.get("/")
async def root():
    """Return static service metadata and a map of the available endpoints."""
    return {
        "name": "Invoice OCR API",
        "version": "3.0.0",
        "pipeline": "nemoretriever-ocr-v1 → nvidia-nemotron-nano-9b-v2",
        "pdf_support": "Each PDF page is rendered to PNG (200 dpi) before OCR",
        "endpoints": {
            "POST /extract-invoice": "Upload tax invoice image or PDF → full structured JSON",
            "GET /health": "Health check",
            "GET /docs": "Swagger UI",
        },
    }
# Script entry point: start a development server when the file is run directly.
if __name__ == "__main__":
    from uvicorn import run as serve

    # The import string "app:app" presumably expects this module to be named
    # app.py -- verify against the deployed filename.
    serve(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=True,
    )