r"""vLLM inference script for the Gemma 4 E4B KYC document extractor.

This script demonstrates how to serve the fine-tuned model using vLLM for
production-level speed and throughput, with Hugging Face Transformers and an
OpenAI-compatible HTTP client as alternative backends.

Requirements (quote version specifiers so the shell does not treat ``>`` as a redirect):
    pip install "vllm>=0.8.0" pillow        # --backend vllm
    pip install torch transformers pillow   # --backend transformers (default)
    pip install openai pillow               # --backend api

Usage:
    # Start a vLLM OpenAI-compatible server
    python -m vllm.entrypoints.openai.api_server \
        --model Jwalit/gemma4-e4b-kyc-document-extractor \
        --trust-remote-code \
        --max-model-len 4096 \
        --dtype bfloat16 \
        --gpu-memory-utilization 0.9

    # ...and query it
    python inference_vllm.py --image path/to/document.jpg --backend api

    # Or run single-image inference in-process with vLLM's offline engine
    python inference_vllm.py --image path/to/document.jpg --backend vllm
"""
import argparse
import base64
import json
import re
from io import BytesIO
from pathlib import Path

MODEL_ID = "Jwalit/gemma4-e4b-kyc-document-extractor"

# Shared by every backend so the three code paths always send the same prompt.
SYSTEM_PROMPT = (
    "You are an expert KYC document analyst. You can classify and extract information "
    "from Indian identity documents including Aadhaar cards, PAN cards, Passports, Visas, "
    "and Election Cards (Voter IDs). Always respond with accurate, structured JSON output."
)

# Per-task user instruction; insertion order also defines the CLI --task choices.
USER_PROMPTS = {
    "classify": "What type of document is shown in this image? Respond with structured JSON.",
    "extract": "Extract all relevant information from this document as a structured JSON.",
    "combined": (
        "First classify this document, then extract all information from it "
        "as structured JSON."
    ),
}

# Sampling settings shared by all backends.
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.1
TOP_P = 0.95

# Matches a whole response wrapped in a Markdown code fence, e.g. ```json ... ```.
_CODE_FENCE_RE = re.compile(r"^```[A-Za-z0-9_+-]*\s*(.*?)\s*```$", re.DOTALL)


def build_user_text(task: str) -> str:
    """Return the user instruction for *task*; unknown tasks fall back to "combined"."""
    return USER_PROMPTS.get(task, USER_PROMPTS["combined"])


def build_openai_messages(image_url: str, task: str) -> list:
    """Build OpenAI-style chat messages (system prompt + image + task instruction).

    Used by both the vLLM offline ``LLM.chat`` path and the OpenAI client path.
    """
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": build_user_text(task)},
            ],
        },
    ]


def strip_code_fences(text: str) -> str:
    """Strip surrounding whitespace and, if present, one enclosing Markdown code fence."""
    stripped = text.strip()
    match = _CODE_FENCE_RE.match(stripped)
    return match.group(1).strip() if match else stripped


def load_rgb_image(image_path: str):
    """Open *image_path* with Pillow and return an RGB copy, closing the file handle."""
    # Imported lazily, like the backend libraries, so `--help` works without Pillow.
    from PIL import Image

    with Image.open(image_path) as img:
        return img.convert("RGB")


def encode_image_data_url(image_path: str) -> str:
    """Return the image at *image_path* re-encoded as a PNG ``data:`` URL."""
    buffer = BytesIO()
    load_rgb_image(image_path).save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/png;base64,{encoded}"


def inference_with_vllm_offline(image_path: str, task: str = "combined") -> str:
    """Run inference using vLLM offline mode (no server needed).

    A new ``LLM`` engine is constructed on every call, so this is suited to
    one-off runs rather than repeated calls in a loop.

    Returns:
        The raw generated text of the first completion.
    """
    from vllm import LLM, SamplingParams

    # Encode before loading the model so a bad image path fails fast. A data URL
    # is used because a ``file://`` URL built from a relative path is malformed,
    # and vLLM only reads local files when allowed_local_media_path is set.
    messages = build_openai_messages(encode_image_data_url(image_path), task)

    llm = LLM(
        model=MODEL_ID,
        trust_remote_code=True,
        max_model_len=4096,
        dtype="bfloat16",
        gpu_memory_utilization=0.9,
    )
    # No explicit stop strings: an empty-string stop sequence would terminate
    # generation immediately, so we rely on the model's end-of-turn/EOS token.
    sampling_params = SamplingParams(
        temperature=TEMPERATURE,
        top_p=TOP_P,
        max_tokens=MAX_NEW_TOKENS,
    )
    outputs = llm.chat(messages, sampling_params=sampling_params)
    return outputs[0].outputs[0].text


def inference_with_transformers(image_path: str, task: str = "combined") -> str:
    """Run inference using Hugging Face Transformers (works without vLLM).

    Returns:
        The decoded newly generated tokens (prompt excluded), special tokens removed.
    """
    import torch
    from transformers import AutoModelForImageTextToText, AutoProcessor

    image = load_rgb_image(image_path)

    print(f"Loading model: {MODEL_ID}")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model.eval()

    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {
            "role": "user",
            "content": [
                # Attach the image to its placeholder; apply_chat_template loads
                # images from message content rather than from a stray kwarg.
                {"type": "image", "image": image},
                {"type": "text", "text": build_user_text(task)},
            ],
        },
    ]

    # Cast floating-point inputs (pixel values) to match the bf16 weights;
    # integer tensors such as input_ids are left untouched by BatchFeature.to.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            do_sample=True,
        )

    # Decode only the tokens generated after the prompt.
    prompt_length = inputs["input_ids"].shape[1]
    generated_ids = output_ids[:, prompt_length:]
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def inference_with_openai_client(
    image_path: str,
    task: str = "combined",
    api_base: str = "http://localhost:8000/v1",
) -> str:
    """Call a running vLLM OpenAI-compatible server.

    Returns:
        The assistant message content, or "" when the server returns none.
    """
    from openai import OpenAI

    # The client requires some key; a placeholder suffices for a local server
    # started without --api-key.
    client = OpenAI(base_url=api_base, api_key="dummy")

    response = client.chat.completions.create(
        model=MODEL_ID,
        messages=build_openai_messages(encode_image_data_url(image_path), task),
        max_tokens=MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
        top_p=TOP_P,
    )
    # message.content may be None; normalize so callers always get a str.
    return response.choices[0].message.content or ""


def main() -> None:
    """Parse CLI arguments, run the chosen backend, and pretty-print the result."""
    parser = argparse.ArgumentParser(description="KYC Document Extraction Inference")
    parser.add_argument("--image", type=str, required=True, help="Path to document image")
    parser.add_argument("--task", choices=list(USER_PROMPTS), default="combined")
    parser.add_argument("--backend", choices=["vllm", "transformers", "api"], default="transformers")
    parser.add_argument("--api-base", type=str, default="http://localhost:8000/v1")
    args = parser.parse_args()

    # Fail before any (slow) model loading if the image does not exist.
    if not Path(args.image).is_file():
        parser.error(f"image not found: {args.image}")

    print("\n🔍 KYC Document Analysis")
    print(f" Image: {args.image}")
    print(f" Task: {args.task}")
    print(f" Backend: {args.backend}\n")

    if args.backend == "vllm":
        result = inference_with_vllm_offline(args.image, args.task)
    elif args.backend == "api":
        result = inference_with_openai_client(args.image, args.task, args.api_base)
    else:
        result = inference_with_transformers(args.image, args.task)

    print("📄 Result:")
    # Models often wrap JSON in ```json fences; strip them before parsing and
    # fall back to the raw text when the output is not valid JSON.
    try:
        parsed = json.loads(strip_code_fences(result))
    except json.JSONDecodeError:
        print(result)
    else:
        print(json.dumps(parsed, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()