# NOTE(review): the three lines below are web-page residue (author avatar
# caption, commit message, commit hash) captured when this file was copied
# from the Hugging Face UI; kept as comments so the module stays valid Python.
# sgonzalezu's picture
# Deploy LLaVA invoice extraction service
# 2c911a0
"""
LLaVA-OneVision Service for Invoice Data Extraction
Hugging Face Space compatible FastAPI application
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import base64
import io
import json
import os
from typing import Dict, Any, Optional
from PIL import Image
# Import transformers and model utilities
from transformers import AutoTokenizer, AutoProcessor, AutoModelForCausalLM
from qwen_vl_utils import process_vision_info
# Import prompts
from prompts import get_invoice_extraction_prompt
# FastAPI application instance served by uvicorn (see __main__ guard).
app = FastAPI(
    title="LLaVA Invoice Extraction Service",
    description="Vision-Language Model for Invoice Data Extraction",
    version="1.0.0",
)

# Model and processor are populated by the startup hook below;
# None means "still loading" and makes the endpoints answer 503.
model = None
processor = None

# Deployment knobs, overridable through the environment.
MODEL_NAME = os.getenv("MODEL_NAME", "lmms-lab/LLaVA-OneVision-1.5-4B-Instruct")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "2048"))
class InvoiceRequest(BaseModel):
    """Payload accepted by POST /extract_invoice."""

    # Base64-encoded image bytes.
    image: str
    # Selects the vendor-specific extraction prompt.
    vendor_id: str = "Default"
    # Reserved flag for an optional validation pass.
    use_validation: bool = False
class InvoiceResponse(BaseModel):
    """Payload returned by POST /extract_invoice."""

    # "success" or "error".
    status: str
    # Parsed invoice fields when extraction succeeded.
    data: Optional[Dict[str, Any]] = None
    # Verbatim model output, useful for debugging parse failures.
    raw_output: Optional[str] = None
    # Human-readable failure description when status == "error".
    error: Optional[str] = None
@app.on_event("startup")
async def load_model():
"""Load the LLaVA model on startup"""
global model, processor
try:
print(f"Loading model: {MODEL_NAME}...")
# Load model with optimizations for inference
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype="auto",
device_map="auto",
trust_remote_code=True,
low_cpu_mem_usage=True
)
# Load processor
processor = AutoProcessor.from_pretrained(
MODEL_NAME,
trust_remote_code=True
)
print(f"✓ Model loaded successfully: {MODEL_NAME}")
except Exception as e:
print(f"ERROR loading model: {str(e)}")
raise
@app.get("/")
def health_check():
"""Health check endpoint"""
return {
"status": "ok",
"service": "LLaVA Invoice Extraction",
"model": MODEL_NAME,
"model_loaded": model is not None
}
@app.post("/extract_invoice", response_model=InvoiceResponse)
async def extract_invoice(request: InvoiceRequest):
"""
Extract invoice data from image using LLaVA-OneVision
Args:
request: InvoiceRequest with base64 image and vendor_id
Returns:
InvoiceResponse with extracted data
"""
global model, processor
if model is None or processor is None:
raise HTTPException(
status_code=503,
detail="Model not loaded. Please wait for startup to complete."
)
try:
# Decode base64 image
image_data = base64.b64decode(request.image)
image = Image.open(io.BytesIO(image_data))
# Convert to RGB if needed
if image.mode != "RGB":
image = image.convert("RGB")
print(f"Processing invoice for vendor: {request.vendor_id}")
print(f"Image size: {image.size}")
# Get appropriate prompt for vendor
prompt_text = get_invoice_extraction_prompt(request.vendor_id)
# Prepare messages for the model
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt_text}
]
}
]
# Prepare inputs for inference
text = processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt"
)
# Move inputs to same device as model
inputs = inputs.to(model.device)
print("Running inference...")
# Generate response
generated_ids = model.generate(
**inputs,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=False, # Deterministic for consistent extraction
temperature=None,
top_p=None
)
# Decode output
generated_ids_trimmed = [
out_ids[len(in_ids):]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)[0]
print(f"Raw output length: {len(output_text)} chars")
# Parse JSON from output
invoice_data = parse_json_from_output(output_text)
if invoice_data:
print(f"✓ Successfully extracted invoice data")
print(f" Items found: {len(invoice_data.get('items', []))}")
print(f" Total: ${invoice_data.get('total', 0)}")
return InvoiceResponse(
status="success",
data=invoice_data,
raw_output=output_text
)
else:
print(f"⚠ Could not parse JSON from output")
return InvoiceResponse(
status="error",
error="Failed to parse JSON from model output",
raw_output=output_text
)
except Exception as e:
print(f"ERROR in extract_invoice: {str(e)}")
import traceback
traceback.print_exc()
return InvoiceResponse(
status="error",
error=str(e)
)
def parse_json_from_output(output: str) -> Optional[Dict]:
    """
    Extract and parse JSON from model output.

    The model may wrap its answer in prose or markdown fences. After a
    direct parse fails, candidate objects are located with
    ``json.JSONDecoder.raw_decode``, which understands JSON string syntax.
    The previous hand-rolled brace counter treated braces inside string
    values as structure, so output like ``{"a": "}"}`` failed to parse.

    Args:
        output: Raw text produced by the model.

    Returns:
        The first successfully parsed JSON value found, or None.
    """
    # Fast path: the whole output is valid JSON.
    try:
        return json.loads(output)
    except json.JSONDecodeError:
        pass
    # Scan for an embedded JSON object: try a structural parse at each '{'.
    # raw_decode ignores trailing text, and correctly skips braces that
    # appear inside string values.
    decoder = json.JSONDecoder()
    idx = output.find('{')
    while idx != -1:
        try:
            candidate, _ = decoder.raw_decode(output, idx)
            return candidate
        except json.JSONDecodeError:
            idx = output.find('{', idx + 1)
    # Last resort: content inside markdown code fences (covers fenced
    # payloads that are not brace-delimited objects).
    if "```json" in output:
        parts = output.split("```json")
        if len(parts) > 1:
            json_part = parts[1].split("```")[0].strip()
            try:
                return json.loads(json_part)
            except json.JSONDecodeError:
                pass
    elif "```" in output:
        parts = output.split("```")
        if len(parts) >= 3:
            json_part = parts[1].strip()
            try:
                return json.loads(json_part)
            except json.JSONDecodeError:
                pass
    return None
@app.get("/health")
def health():
"""Health endpoint for monitoring"""
return {"status": "healthy"}
if __name__ == "__main__":
import uvicorn
port = int(os.getenv("PORT", "7860"))
uvicorn.run(app, host="0.0.0.0", port=port)