Spaces:
Runtime error
Runtime error
| """ | |
| LLaVA-OneVision Service for Invoice Data Extraction | |
| Hugging Face Space compatible FastAPI application | |
| """ | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import base64 | |
| import io | |
| import json | |
| import os | |
| from typing import Dict, Any, Optional | |
| from PIL import Image | |
| # Import transformers and model utilities | |
| from transformers import AutoTokenizer, AutoProcessor, AutoModelForCausalLM | |
| from qwen_vl_utils import process_vision_info | |
| # Import prompts | |
| from prompts import get_invoice_extraction_prompt | |
# Initialize FastAPI
app = FastAPI(
    title="LLaVA Invoice Extraction Service",
    description="Vision-Language Model for Invoice Data Extraction",
    version="1.0.0"
)

# Global model variables (loaded on startup)
# Both remain None until load_model() completes; endpoints must check for this.
model = None
processor = None

# Model configuration
# MODEL_NAME: Hugging Face repo id of the vision-language model (env-overridable).
MODEL_NAME = os.getenv("MODEL_NAME", "lmms-lab/LLaVA-OneVision-1.5-4B-Instruct")
# MAX_NEW_TOKENS: hard cap on generated tokens per extraction pass.
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "2048"))
class InvoiceRequest(BaseModel):
    """Request model for invoice extraction.

    Sent by the caller as JSON; `image` must be a base64-encoded image file
    (any PIL-readable format — it is converted to RGB before inference).
    """
    image: str  # Base64 encoded image
    vendor_id: str = "Default"  # selects the vendor-specific extraction prompt
    # NOTE(review): use_validation is accepted but never read anywhere in this
    # file — confirm whether a validation pass was intended to be wired up.
    use_validation: bool = False  # Optional: run validation pass
class InvoiceResponse(BaseModel):
    """Response model for invoice extraction.

    `status` is either "success" (data populated) or "error" (error populated).
    `raw_output` carries the model's untrimmed text in both cases when available,
    so callers can debug parsing failures.
    """
    status: str
    data: Optional[Dict[str, Any]] = None  # parsed invoice JSON on success
    raw_output: Optional[str] = None       # raw model output text
    error: Optional[str] = None            # human-readable failure reason
# NOTE(review): no startup registration for this coroutine was visible in the
# source (likely stripped during extraction). Without a startup hook the model
# never loads, so every /extract call returns 503 — restoring the hook here.
@app.on_event("startup")
async def load_model():
    """Load the LLaVA model and processor once at application startup.

    Populates the module-level ``model`` and ``processor`` globals that the
    extraction endpoint reads. Any loading failure is re-raised so the Space
    fails fast instead of serving a permanently broken endpoint.
    """
    global model, processor
    try:
        print(f"Loading model: {MODEL_NAME}...")
        # Load model with optimizations for inference
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype="auto",       # let transformers pick fp16/bf16/fp32
            device_map="auto",        # place weights on available devices
            trust_remote_code=True,   # model repo ships custom code
            low_cpu_mem_usage=True
        )
        # Load the matching processor (tokenizer + image preprocessing)
        processor = AutoProcessor.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True
        )
        print(f"✓ Model loaded successfully: {MODEL_NAME}")
    except Exception as e:
        print(f"ERROR loading model: {str(e)}")
        raise
# NOTE(review): the docstring calls this an "endpoint" but no route decorator
# was visible in the source — registering it at "/" (confirm intended path).
@app.get("/")
def health_check():
    """Health check endpoint.

    Returns service identity plus whether the model finished loading, so
    orchestrators can distinguish "starting up" from "ready".
    """
    return {
        "status": "ok",
        "service": "LLaVA Invoice Extraction",
        "model": MODEL_NAME,
        "model_loaded": model is not None
    }
# NOTE(review): no route decorator was visible in the source (likely stripped
# during extraction) — without it this handler is unreachable. Registering at
# POST /extract; confirm the path the upstream client actually calls.
@app.post("/extract", response_model=InvoiceResponse)
async def extract_invoice(request: InvoiceRequest):
    """
    Extract invoice data from image using LLaVA-OneVision.

    Args:
        request: InvoiceRequest with base64 image and vendor_id

    Returns:
        InvoiceResponse with extracted data on success, or status="error"
        with the raw model output attached for debugging.

    Raises:
        HTTPException: 503 if the model has not finished loading yet.
    """
    global model, processor
    if model is None or processor is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Please wait for startup to complete."
        )
    try:
        # Decode base64 image
        image_data = base64.b64decode(request.image)
        image = Image.open(io.BytesIO(image_data))
        # Convert to RGB if needed (model expects 3-channel input)
        if image.mode != "RGB":
            image = image.convert("RGB")
        print(f"Processing invoice for vendor: {request.vendor_id}")
        print(f"Image size: {image.size}")
        # Get appropriate prompt for vendor
        prompt_text = get_invoice_extraction_prompt(request.vendor_id)
        # Prepare messages in the chat format the processor expects
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt_text}
                ]
            }
        ]
        # Render the chat template, then extract vision inputs separately
        text = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        )
        # Move inputs to same device as model
        inputs = inputs.to(model.device)
        print("Running inference...")
        # Greedy decoding: temperature/top_p explicitly nulled so sampling
        # defaults from the model config cannot leak in.
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,  # Deterministic for consistent extraction
            temperature=None,
            top_p=None
        )
        # Strip the prompt tokens so only the generated continuation is decoded
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]
        print(f"Raw output length: {len(output_text)} chars")
        # Parse JSON from output
        invoice_data = parse_json_from_output(output_text)
        if invoice_data:
            print(f"✓ Successfully extracted invoice data")
            print(f"  Items found: {len(invoice_data.get('items', []))}")
            print(f"  Total: ${invoice_data.get('total', 0)}")
            return InvoiceResponse(
                status="success",
                data=invoice_data,
                raw_output=output_text
            )
        else:
            print(f"⚠ Could not parse JSON from output")
            return InvoiceResponse(
                status="error",
                error="Failed to parse JSON from model output",
                raw_output=output_text
            )
    except Exception as e:
        # Best-effort error response: keep the service up and report the
        # failure in-band rather than surfacing a 500.
        print(f"ERROR in extract_invoice: {str(e)}")
        import traceback
        traceback.print_exc()
        return InvoiceResponse(
            status="error",
            error=str(e)
        )
def parse_json_from_output(output: str) -> Optional[Dict]:
    """
    Extract and parse JSON from model output.

    The model might include extra text, so we need to find the JSON part:
    first try the whole string, then every balanced ``{...}`` span, then
    markdown code fences.

    Args:
        output: Raw output from model

    Returns:
        Parsed JSON dict or None

    Bug fix vs. original: when a balanced candidate failed to parse, the old
    scan kept the original start index and let the brace counter go negative,
    so a later valid object (e.g. '{bad} {"a": 1}') was never found. Each '{'
    is now tried as a fresh candidate start.
    """
    # Try direct JSON parse first
    try:
        return json.loads(output)
    except json.JSONDecodeError:
        pass

    # Try every '{' as a candidate object start; for each, find its matching
    # closing brace and attempt a parse of that balanced span.
    for start, ch in enumerate(output):
        if ch != '{':
            continue
        depth = 0
        for end in range(start, len(output)):
            c = output[end]
            if c == '{':
                depth += 1
            elif c == '}':
                depth -= 1
                if depth == 0:
                    # Balanced span found — parse it or move to the next '{'
                    try:
                        return json.loads(output[start:end + 1])
                    except json.JSONDecodeError:
                        break

    # If no valid JSON found, try with markdown code blocks
    if "```json" in output:
        parts = output.split("```json")
        if len(parts) > 1:
            json_part = parts[1].split("```")[0].strip()
            try:
                return json.loads(json_part)
            except json.JSONDecodeError:
                pass
    elif "```" in output:
        parts = output.split("```")
        if len(parts) >= 3:
            json_part = parts[1].strip()
            try:
                return json.loads(json_part)
            except json.JSONDecodeError:
                pass
    return None
# NOTE(review): no route decorator was visible in the source — registering at
# /health per the docstring ("monitoring"). This is a liveness probe only; it
# does NOT check model readiness (health_check() reports model_loaded).
@app.get("/health")
def health():
    """Health endpoint for monitoring"""
    return {"status": "healthy"}
if __name__ == "__main__":
    # Local / Space entry point: serve the app with uvicorn on all interfaces.
    # Hugging Face Spaces injects PORT; fall back to the Spaces default 7860.
    import uvicorn

    server_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=server_port)