# Author: Jwalit
# Commit 15dad3d (verified):
# Add inference script with vLLM, Transformers, and OpenAI API backends
"""
vLLM Inference Script for Gemma 4 E4B KYC Document Extractor.
This script demonstrates how to serve the fine-tuned model using vLLM
for production-level speed and throughput.
Requirements:
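pip install 'vllm>=0.8.0' pillow
(the transformers backend additionally needs torch and transformers; the api backend needs openai)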
Usage:
# Start vLLM server
python -m vllm.entrypoints.openai.api_server \
--model Jwalit/gemma4-e4b-kyc-document-extractor \
--trust-remote-code \
--max-model-len 4096 \
--dtype bfloat16 \
--gpu-memory-utilization 0.9
# Or run this script directly on a single image (offline vLLM backend)
python inference_vllm.py --image path/to/document.jpg --backend vllm
"""
import argparse
import json
import base64
from io import BytesIO
from PIL import Image
def inference_with_vllm_offline(image_path: str, task: str = "combined"):
    """Run inference using vLLM offline mode (no server needed).

    Args:
        image_path: Path to the document image on local disk.
        task: ``"classify"`` or ``"extract"``; any other value selects the
            combined classify-then-extract prompt.

    Returns:
        The raw text generated by the model for the single request.
    """
    from vllm import LLM, SamplingParams

    # Open the image before loading the model so that an unreadable path
    # fails immediately instead of after the (slow) model initialisation.
    image = Image.open(image_path).convert("RGB")

    # Send the image inline as a base64 data URL. A "file://" URL built from
    # a relative path is malformed, and vLLM refuses local-file URLs unless
    # ``allowed_local_media_path`` is configured on the engine.
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    img_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

    # Load model with vLLM
    llm = LLM(
        model="Jwalit/gemma4-e4b-kyc-document-extractor",
        trust_remote_code=True,
        max_model_len=4096,
        dtype="bfloat16",
        gpu_memory_utilization=0.9,
    )

    # Build prompt based on task
    if task == "classify":
        user_text = "What type of document is shown in this image? Respond with structured JSON."
    elif task == "extract":
        user_text = "Extract all relevant information from this document as a structured JSON."
    else:  # combined
        user_text = "First classify this document, then extract all information from it as structured JSON."

    system_text = (
        "You are an expert KYC document analyst. You can classify and extract information "
        "from Indian identity documents including Aadhaar cards, PAN cards, Passports, Visas, "
        "and Election Cards (Voter IDs). Always respond with accurate, structured JSON output."
    )

    # Format as chat messages for vLLM
    messages = [
        {"role": "system", "content": system_text},
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}},
                {"type": "text", "text": user_text},
            ],
        },
    ]

    sampling_params = SamplingParams(
        temperature=0.1,
        top_p=0.95,
        max_tokens=1024,
        stop=["<end_of_turn>"],
    )

    outputs = llm.chat(messages, sampling_params=sampling_params)
    # One conversation was submitted, so take its first (only) completion.
    return outputs[0].outputs[0].text
def inference_with_transformers(image_path: str, task: str = "combined"):
    """Run inference using HuggingFace Transformers (works without vLLM).

    Args:
        image_path: Path to the document image on local disk.
        task: ``"classify"`` or ``"extract"``; any other value selects the
            combined classify-then-extract prompt.

    Returns:
        The decoded text of the newly generated tokens (prompt excluded).
    """
    import torch
    from transformers import AutoProcessor, AutoModelForImageTextToText

    model_id = "Jwalit/gemma4-e4b-kyc-document-extractor"
    print(f"Loading model: {model_id}")
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model.eval()

    # Load image
    image = Image.open(image_path).convert("RGB")

    # Build prompt
    if task == "classify":
        user_text = "What type of document is shown in this image? Respond with structured JSON."
    elif task == "extract":
        user_text = "Extract all relevant information from this document as a structured JSON."
    else:
        user_text = "First classify this document, then extract all information from it as structured JSON."

    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are an expert KYC document analyst. You can classify and extract information from Indian identity documents including Aadhaar cards, PAN cards, Passports, Visas, and Election Cards (Voter IDs). Always respond with accurate, structured JSON output."}],
        },
        {
            "role": "user",
            "content": [
                # The image travels inside the message: when tokenizing, the
                # processor collects visuals from the chat content itself, so
                # a bare {"type": "image"} placeholder plus a separate
                # ``images=`` kwarg is not the supported way to attach it.
                {"type": "image", "image": image},
                {"type": "text", "text": user_text},
            ],
        },
    ]

    # Process inputs
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    # Generate
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.1,
            top_p=0.95,
            do_sample=True,
        )

    # Decode only new tokens: slice off the prompt prefix by its length.
    prompt_len = inputs["input_ids"].shape[1]
    generated_ids = output_ids[:, prompt_len:]
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
def inference_with_openai_client(image_path: str, task: str = "combined",
                                 api_base: str = "http://localhost:8000/v1"):
    """Call a running vLLM OpenAI-compatible server.

    Args:
        image_path: Path to the document image on local disk.
        task: ``"classify"`` or ``"extract"``; any other value selects the
            combined classify-then-extract prompt.
        api_base: Base URL of the OpenAI-compatible endpoint.

    Returns:
        The assistant message text, or an empty string when the server
        returns no content.
    """
    from openai import OpenAI

    client = OpenAI(base_url=api_base, api_key="dummy")

    # Encode image to base64 (re-encoded as PNG to match the data URL below)
    image = Image.open(image_path).convert("RGB")
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    img_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

    if task == "classify":
        user_text = "What type of document is shown in this image? Respond with structured JSON."
    elif task == "extract":
        user_text = "Extract all relevant information from this document as a structured JSON."
    else:
        user_text = "First classify this document, then extract all information from it as structured JSON."

    response = client.chat.completions.create(
        model="Jwalit/gemma4-e4b-kyc-document-extractor",
        messages=[
            {
                "role": "system",
                "content": "You are an expert KYC document analyst. You can classify and extract information from Indian identity documents including Aadhaar cards, PAN cards, Passports, Visas, and Election Cards (Voter IDs). Always respond with accurate, structured JSON output.",
            },
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}},
                    {"type": "text", "text": user_text},
                ],
            },
        ],
        max_tokens=1024,
        temperature=0.1,
        # Same nucleus-sampling setting as the vLLM and Transformers backends
        # in this file, so all three sample under the same parameters.
        top_p=0.95,
    )

    # ``content`` may be None; normalise to "" so callers always get a string
    # they can hand to json.loads without a TypeError.
    return response.choices[0].message.content or ""
if __name__ == "__main__":
    # Command-line entry point: parse options, run the selected backend on
    # one image, and print the model output (pretty-printed when it is JSON).
    parser = argparse.ArgumentParser(description="KYC Document Extraction Inference")
    parser.add_argument("--image", type=str, required=True, help="Path to document image")
    parser.add_argument("--task", choices=["classify", "extract", "combined"], default="combined")
    parser.add_argument("--backend", choices=["vllm", "transformers", "api"], default="transformers")
    parser.add_argument("--api-base", type=str, default="http://localhost:8000/v1")
    args = parser.parse_args()

    print("\n🔍 KYC Document Analysis")
    print(f" Image: {args.image}")
    print(f" Task: {args.task}")
    print(f" Backend: {args.backend}\n")

    if args.backend == "vllm":
        result = inference_with_vllm_offline(args.image, args.task)
    elif args.backend == "api":
        result = inference_with_openai_client(args.image, args.task, args.api_base)
    else:
        result = inference_with_transformers(args.image, args.task)

    print("📄 Result:")
    try:
        parsed = json.loads(result)
    except (json.JSONDecodeError, TypeError):
        # Not valid JSON (or not a string at all): show the raw output as-is.
        print(result)
    else:
        # ensure_ascii=False keeps non-ASCII field values (e.g. names printed
        # in Indic scripts) readable instead of emitting \uXXXX escapes.
        print(json.dumps(parsed, indent=2, ensure_ascii=False))