"""
Inference Processor - Handles VLM extraction, validation, and result formatting
"""
|
|
| import torch |
| import time |
| import json |
| import codecs |
| import re |
| from PIL import Image |
| from qwen_vl_utils import process_vision_info |
| from typing import Dict, Tuple |
|
|
| from config import ( |
| MAX_IMAGE_SIZE, |
| HP_VALID_RANGE, |
| ASSET_COST_VALID_RANGE, |
| COST_PER_GPU_HOUR |
| ) |
| from model_manager import model_manager |
|
|
|
|
# Instruction text sent to the vision-language model together with the invoice
# image (see InferenceProcessor.run_vlm_extraction). It asks for the four
# fields as a bare JSON object. The wording is runtime behavior: changing it
# changes what the model returns.
EXTRACTION_PROMPT = """
You are an expert at reading noisy, handwritten Indian invoices and quotations.

Your task is to extract text EXACTLY as it appears in the image.
Do NOT translate, summarize, normalize, or rewrite any text.
Preserve the original language (Hindi, Marathi, Kannada, English, etc.).

Carefully read the image and extract the following fields.

Return ONLY valid JSON in this format:

{
"dealer_name": string,
"model_name": string,
"horse_power": number,
"asset_cost": number
}

Critical rules:
- Dealer name must be copied exactly from the image in the original language and spelling.
- Model name must be copied exactly from the image without translation.
- Do NOT convert regional language text into English.
- Do NOT expand abbreviations or correct spelling.
- Only numbers may be normalized.

Extraction hints:
- Asset cost is the total amount, usually the largest number on the page, the total amount after TAX, final price or final cost.
- Dealer name is usually at the top header or company name.
- Model name often appears near words like Model, Tractor, Variant.
- Horse power must come ONLY from explicit HP text, never from model numbers.
- Horse power may appear as "HP", handwritten like "49 HP", "63hp", "HP-30".
- Remove commas and currency symbols from numbers only.
- If handwriting is unclear, make your best reasonable interpretation of the characters — but preserve language.

Output rules:
- Output ONLY valid JSON.
- Do NOT include markdown, explanations, or extra text.
"""
|
|
|
|
class InferenceProcessor:
    """Stateless helpers for VLM inference, validation, and result processing.

    Every method is a ``@staticmethod``. The model and processor are read from
    the module-level ``model_manager``; the validation ranges and limits come
    from the ``config`` constants imported at the top of the file.
    """

    # Keys that every extraction result is expected to carry.
    _REQUIRED_KEYS = ("dealer_name", "model_name", "horse_power", "asset_cost")

    # Formatting noise stripped from numeric strings before a second parse:
    # "Rs"/"Rs."/"INR" markers, a "/-" suffix, rupee/dollar signs, thousands
    # separators, and whitespace.
    _NUMBER_NOISE_RE = re.compile(r"rs\.?|inr|/-|[\u20b9$,\s]", re.IGNORECASE)

    # Explicit horse-power mentions: "HP-30" / "HP 30" and "49 HP" / "63hp".
    _HP_PREFIX_RE = re.compile(r"HP[- ]?(\d+)", re.IGNORECASE)
    _HP_SUFFIX_RE = re.compile(r"(\d+)\s*HP\b", re.IGNORECASE)

    @staticmethod
    def preprocess_image(image_path: str) -> "Image.Image":
        """Load an image as RGB and downscale it if it exceeds MAX_IMAGE_SIZE.

        Args:
            image_path: Path of the image file to open.

        Returns:
            An RGB image whose longest side is at most ``MAX_IMAGE_SIZE``.
            The caller owns it and should close it when done.
        """
        # convert() returns an independent image, so the file handle opened
        # by Image.open can be released right away.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        longest_side = max(image.size)
        if longest_side > MAX_IMAGE_SIZE:
            ratio = MAX_IMAGE_SIZE / longest_side
            # Clamp to 1px so an extreme aspect ratio cannot yield a 0-sized side.
            new_size = (
                max(1, int(image.size[0] * ratio)),
                max(1, int(image.size[1] * ratio)),
            )
            image = image.resize(new_size, Image.LANCZOS)
            print(f"🔄 Image resized to {new_size}")

        return image

    @staticmethod
    def run_vlm_extraction(image: "Image.Image", max_new_tokens: int = 256) -> Tuple[str, float]:
        """Run the VLM on one image with EXTRACTION_PROMPT.

        Args:
            image: Preprocessed invoice image.
            max_new_tokens: Upper bound on generated tokens (default 256).

        Returns:
            ``(output_text, latency)``: the decoded model answer and the
            seconds spent inside ``model.generate``.

        Raises:
            RuntimeError: If ``model_manager.is_loaded()`` is false.
        """
        if not model_manager.is_loaded():
            raise RuntimeError("Models not loaded")

        model = model_manager.vlm_model
        processor = model_manager.processor

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": EXTRACTION_PROMPT},
                ],
            }
        ]

        prompt_text = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[prompt_text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Put the inputs on the model's device instead of assuming CUDA.
        # NOTE(review): assumes the loaded model exposes `.device` (as
        # transformers models do); otherwise fall back to CUDA when available.
        device = getattr(model, "device", None)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = inputs.to(device)

        generated_ids = None
        trimmed_ids = None
        try:
            start = time.perf_counter()
            generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
            latency = time.perf_counter() - start

            # generate() returns prompt + completion; keep only the completion.
            trimmed_ids = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            decoded = processor.batch_decode(
                trimmed_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
        finally:
            # Drop tensor references even when generation fails, then release
            # cached GPU memory.
            del inputs, generated_ids, trimmed_ids
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        output_text = decoded[0] if isinstance(decoded, list) else decoded
        return output_text, latency

    @staticmethod
    def _load_json_object(candidate: str):
        """Parse ``candidate`` as JSON; return it only if it is a dict, else None."""
        try:
            parsed = json.loads(candidate)
        except (ValueError, RecursionError):
            # json.JSONDecodeError is a ValueError subclass.
            return None
        return parsed if isinstance(parsed, dict) else None

    @staticmethod
    def extract_json_from_output(text: str) -> Dict:
        """Extract a JSON object from raw model output.

        Strategies, tried in order:
          1. Exactly one or two ``` fences: parse what follows the first fence.
          2. A fenced ```json {...}``` block found by regex.
          3. Any brace-delimited object in the text (one nesting level) that
             contains all four required keys.

        Only JSON objects are accepted; a fenced list or scalar is skipped so
        callers can rely on ``.get``.

        Args:
            text: Raw decoded model output.

        Returns:
            The parsed dict, or a dict with the four required keys set to
            ``None`` when nothing usable is found.
        """
        required = InferenceProcessor._REQUIRED_KEYS
        empty_result = dict.fromkeys(required)
        if not isinstance(text, str):
            return empty_result

        load = InferenceProcessor._load_json_object

        if text.count('```') in (1, 2):
            fenced = text.split('```')[1]
            if fenced.startswith('json'):
                fenced = fenced[4:]
            parsed = load(fenced.strip())
            if parsed is not None:
                return parsed

        fenced_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', text, re.DOTALL)
        if fenced_match:
            parsed = load(fenced_match.group(1))
            if parsed is not None:
                return parsed

        for candidate in re.finditer(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL):
            parsed = load(candidate.group(0))
            if parsed is not None and all(key in parsed for key in required):
                return parsed

        return empty_result

    @staticmethod
    def clean_text(text):
        """Trim a text field and collapse runs of whitespace.

        Args:
            text: Any value; it is converted with ``str()``.

        Returns:
            The cleaned string, or ``None`` when the input is falsy or the
            cleaned string is shorter than two characters.
        """
        if not text:
            return None
        text = re.sub(r"\s+", " ", str(text).strip())
        return text if len(text) > 1 else None

    @staticmethod
    def clean_number(num):
        """Convert a value to ``int``, truncating any fractional part.

        Strings that fail a direct parse are retried after stripping thousands
        separators, whitespace, and currency markers, so "5,25,000" and
        "Rs. 5,25,000/-" both yield 525000.

        Args:
            num: A number, numeric string, or ``None``.

        Returns:
            The integer value, or ``None`` if the input is ``None`` or cannot
            be interpreted as a finite number.
        """
        if num is None:
            return None
        try:
            return int(float(num))
        except (TypeError, ValueError, OverflowError):
            # ValueError also covers NaN; OverflowError covers infinity.
            pass

        if not isinstance(num, str):
            return None
        stripped = InferenceProcessor._NUMBER_NOISE_RE.sub("", num)
        try:
            return int(float(stripped))
        except (ValueError, OverflowError):
            return None

    @staticmethod
    def fix_horse_power(vlm_hp, model_name) -> Tuple:
        """Return a plausible horse-power value and a confidence score.

        Args:
            vlm_hp: Numeric HP reported by the model, or ``None``.
            model_name: Model name text, searched for an explicit HP mention
                ("HP-30", "HP 30", "49 HP", "63hp") when ``vlm_hp`` is unusable.

        Returns:
            ``(hp, confidence)``: ``(vlm_hp, 1.0)`` if it lies in
            ``HP_VALID_RANGE``; ``(hp, 0.8)`` if recovered from the model
            name; otherwise ``(None, 0.2)``.
        """
        low, high = HP_VALID_RANGE[0], HP_VALID_RANGE[1]

        if isinstance(vlm_hp, (int, float)) and low <= vlm_hp <= high:
            return vlm_hp, 1.0

        if model_name and isinstance(model_name, str):
            # Try the "HP-30" form first so "575 HP-45" yields 45, not 575.
            for pattern in (InferenceProcessor._HP_PREFIX_RE, InferenceProcessor._HP_SUFFIX_RE):
                for match in pattern.finditer(model_name):
                    hp = int(match.group(1))
                    if low <= hp <= high:
                        return hp, 0.8

        return None, 0.2

    @staticmethod
    def validate_asset_cost(cost) -> Tuple:
        """Validate the asset cost against ``ASSET_COST_VALID_RANGE``.

        Args:
            cost: Raw or already-cleaned cost value.

        Returns:
            ``(cost, 1.0)`` when in range, ``(None, 0.3)`` when out of range,
            ``(None, 0.2)`` when missing or not parseable as a number.
        """
        # Clean before the None check so an unparseable value is treated as
        # missing instead of being compared against the range.
        cost = InferenceProcessor.clean_number(cost)
        if cost is None:
            return None, 0.2

        if ASSET_COST_VALID_RANGE[0] <= cost <= ASSET_COST_VALID_RANGE[1]:
            return cost, 1.0

        return None, 0.3

    @staticmethod
    def validate_text_field(text) -> Tuple:
        """Validate a free-text field.

        Returns:
            ``(cleaned_text, 1.0)`` when the cleaned text has at least three
            characters, otherwise ``(None, 0.3)``.
        """
        text = InferenceProcessor.clean_text(text)
        if not text or len(text) < 3:
            return None, 0.3
        return text, 1.0

    @staticmethod
    def validate_prediction(raw_json: Dict) -> Tuple[Dict, float, list]:
        """Validate and repair the four extracted fields.

        Args:
            raw_json: Parsed model output; anything that is not a dict is
                treated as an empty result.

        Returns:
            ``(validated, field_confidence, warnings)``: the validated field
            dict, the mean of the four per-field confidences rounded to three
            decimals, and one warning string per invalid field.
        """
        if not isinstance(raw_json, dict):
            raw_json = {}

        warnings = []
        confidences = []

        dealer, dealer_conf = InferenceProcessor.validate_text_field(raw_json.get("dealer_name"))
        if dealer is None:
            warnings.append("Dealer name invalid")
        confidences.append(dealer_conf)

        model_name, model_conf = InferenceProcessor.validate_text_field(raw_json.get("model_name"))
        if model_name is None:
            warnings.append("Model name invalid")
        confidences.append(model_conf)

        hp_raw = InferenceProcessor.clean_number(raw_json.get("horse_power"))
        hp, hp_conf = InferenceProcessor.fix_horse_power(hp_raw, model_name)
        if hp is None:
            warnings.append("Horse power invalid")
        confidences.append(hp_conf)

        cost, cost_conf = InferenceProcessor.validate_asset_cost(raw_json.get("asset_cost"))
        if cost is None:
            warnings.append("Asset cost invalid")
        confidences.append(cost_conf)

        field_confidence = round(sum(confidences) / len(confidences), 3)

        validated = {
            "dealer_name": dealer,
            "model_name": model_name,
            "horse_power": hp,
            "asset_cost": cost,
        }

        return validated, field_confidence, warnings

    @staticmethod
    def process_invoice(image_path: str, doc_id: str = None) -> Dict:
        """Run the complete invoice processing pipeline.

        Steps: preprocess the image, run signature/stamp detection, run VLM
        extraction, parse the JSON answer, validate the fields, and assemble
        the result with confidence, timing, and cost figures.

        Args:
            image_path: Path to the invoice image.
            doc_id: Document identifier; defaults to the image file name
                without its extension.

        Returns:
            dict with keys ``doc_id``, ``fields``, ``confidence``,
            ``processing_time_sec``, ``timing_breakdown``,
            ``cost_estimate_usd`` and ``warnings`` (``None`` when empty).
        """
        total_start = time.perf_counter()
        timing_breakdown = {}

        if doc_id is None:
            import os
            doc_id = os.path.splitext(os.path.basename(image_path))[0]

        step_start = time.perf_counter()
        image = InferenceProcessor.preprocess_image(image_path)
        timing_breakdown['image_preprocessing'] = round(time.perf_counter() - step_start, 3)

        try:
            step_start = time.perf_counter()
            # Detection works from the file path, not the resized image.
            signature_info, stamp_info, signature_conf, stamp_conf = (
                model_manager.detect_sign_stamp(image_path)
            )
            timing_breakdown['yolo_detection'] = round(time.perf_counter() - step_start, 3)

            vlm_output, vlm_latency = InferenceProcessor.run_vlm_extraction(image)
            timing_breakdown['vlm_inference'] = round(vlm_latency, 3)
        finally:
            # Release the image even when detection or inference fails.
            image.close()
            del image

        step_start = time.perf_counter()
        raw_json = InferenceProcessor.extract_json_from_output(vlm_output)
        timing_breakdown['json_parsing'] = round(time.perf_counter() - step_start, 3)

        step_start = time.perf_counter()
        validated_fields, field_confidence, warnings = InferenceProcessor.validate_prediction(raw_json)
        timing_breakdown['validation'] = round(time.perf_counter() - step_start, 3)

        validated_fields["signature"] = signature_info
        validated_fields["stamp"] = stamp_info

        # Detector confidences count only for elements that were found.
        confidences = [field_confidence]
        if signature_info["present"]:
            confidences.append(signature_conf)
        if stamp_info["present"]:
            confidences.append(stamp_conf)

        overall_confidence = round(sum(confidences) / len(confidences), 3)

        total_time = time.perf_counter() - total_start
        cost_estimate = (COST_PER_GPU_HOUR * total_time) / 3600

        return {
            "doc_id": doc_id,
            "fields": validated_fields,
            "confidence": overall_confidence,
            "processing_time_sec": round(total_time, 2),
            "timing_breakdown": timing_breakdown,
            "cost_estimate_usd": round(cost_estimate, 6),
            "warnings": warnings if warnings else None,
        }
|
|