File size: 7,730 Bytes
15dad3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""
vLLM Inference Script for Gemma 4 E4B KYC Document Extractor.

This script demonstrates how to serve the fine-tuned model using vLLM
for production-level speed and throughput.

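Requirements:
    pip install "vllm>=0.8.0" pillow              # --backend vllm (and the server below)
    pip install transformers torch pillow         # --backend transformers (default)
    pip install openai pillow                     # --backend api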

Usage:
    # Start vLLM server
    python -m vllm.entrypoints.openai.api_server \
        --model Jwalit/gemma4-e4b-kyc-document-extractor \
        --trust-remote-code \
        --max-model-len 4096 \
        --dtype bfloat16 \
        --gpu-memory-utilization 0.9

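    # Or run this script directly on a single image (in-process vLLM engine)
    python inference_vllm.py --image path/to/document.jpg --backend vllm

    # Or send the image to the server started above
    python inference_vllm.py --image path/to/document.jpg --backend api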
"""

import argparse
import json
import base64
from io import BytesIO
from PIL import Image


def inference_with_vllm_offline(image_path: str, task: str = "combined"):
    """Run inference using vLLM offline mode (no server needed).

    Builds an in-process vLLM engine for the fine-tuned model, sends a single
    chat turn (system prompt + document image + task instruction) and returns
    the raw generated text. The prompt asks for structured JSON, but the
    returned string is not validated here.

    Args:
        image_path: Path to the document image on the local filesystem.
        task: ``"classify"`` or ``"extract"``; any other value selects the
            combined classify-then-extract prompt.

    Returns:
        The text of the first completion for the single request.
    """
    from vllm import LLM, SamplingParams

    # Load the image before the (slow) engine start-up so a bad path fails fast.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Ship the image inline as a base64 data URL. A file:// URL built from
    # image_path breaks for relative paths (the first path segment is parsed
    # as the URL host), and vLLM only reads local media files when the engine
    # is created with ``allowed_local_media_path``.
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    img_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
    image_url = f"data:image/png;base64,{img_b64}"

    # Load model with vLLM
    llm = LLM(
        model="Jwalit/gemma4-e4b-kyc-document-extractor",
        trust_remote_code=True,
        max_model_len=4096,
        dtype="bfloat16",
        gpu_memory_utilization=0.9,
    )

    # Build prompt based on task
    if task == "classify":
        user_text = "What type of document is shown in this image? Respond with structured JSON."
    elif task == "extract":
        user_text = "Extract all relevant information from this document as a structured JSON."
    else:  # combined
        user_text = "First classify this document, then extract all information from it as structured JSON."

    system_text = (
        "You are an expert KYC document analyst. You can classify and extract information "
        "from Indian identity documents including Aadhaar cards, PAN cards, Passports, Visas, "
        "and Election Cards (Voter IDs). Always respond with accurate, structured JSON output."
    )

    # OpenAI-style chat messages, which LLM.chat renders via the model's chat template.
    messages = [
        {"role": "system", "content": system_text},
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": user_text},
            ],
        },
    ]

    sampling_params = SamplingParams(
        temperature=0.1,
        top_p=0.95,
        max_tokens=1024,
        # Gemma-style end-of-turn marker; assumed to match this model's chat
        # template -- confirm if the template changes.
        stop=["<end_of_turn>"],
    )

    outputs = llm.chat(messages, sampling_params=sampling_params)

    # One request in, so take the first request's first completion.
    return outputs[0].outputs[0].text


def inference_with_transformers(image_path: str, task: str = "combined"):
    """Run inference using HuggingFace Transformers (works without vLLM).

    Loads the processor and model, renders a system + (image, instruction)
    chat through the processor's chat template, samples up to 1024 new
    tokens, and decodes only the newly generated part.

    Args:
        image_path: Path to the document image on the local filesystem.
        task: ``"classify"`` or ``"extract"``; any other value selects the
            combined classify-then-extract prompt.

    Returns:
        The decoded generated text, with special tokens stripped.
    """
    import torch
    from transformers import AutoProcessor, AutoModelForImageTextToText

    model_id = "Jwalit/gemma4-e4b-kyc-document-extractor"

    # Load the image before the (slow) model load so a bad path fails fast.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    print(f"Loading model: {model_id}")
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model.eval()

    # Build prompt
    if task == "classify":
        user_text = "What type of document is shown in this image? Respond with structured JSON."
    elif task == "extract":
        user_text = "Extract all relevant information from this document as a structured JSON."
    else:
        user_text = "First classify this document, then extract all information from it as structured JSON."

    system_text = (
        "You are an expert KYC document analyst. You can classify and extract information "
        "from Indian identity documents including Aadhaar cards, PAN cards, Passports, Visas, "
        "and Election Cards (Voter IDs). Always respond with accurate, structured JSON output."
    )

    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_text}],
        },
        {
            "role": "user",
            "content": [
                # The image goes inside the message entry. The processor's
                # chat template collects images from {"type": "image",
                # "image": ...} parts, not from a separate ``images=`` argument.
                {"type": "image", "image": image},
                {"type": "text", "text": user_text},
            ],
        },
    ]

    # Process inputs. BatchFeature.to casts only floating-point tensors
    # (e.g. pixel values) to bfloat16; integer token ids keep their dtype.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)

    # Generate
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.1,
            top_p=0.95,
            do_sample=True,
        )

    # generate() returns prompt + continuation; slice off the prompt tokens.
    generated_ids = output_ids[:, inputs["input_ids"].shape[1]:]
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def inference_with_openai_client(image_path: str, task: str = "combined",
                                  api_base: str = "http://localhost:8000/v1"):
    """Call a running vLLM OpenAI-compatible server.

    Re-encodes the image as PNG, sends it inline as a base64 data URL together
    with the system prompt and task instruction, and returns the assistant's
    reply text.

    Args:
        image_path: Path to the document image on the local filesystem.
        task: ``"classify"`` or ``"extract"``; any other value selects the
            combined classify-then-extract prompt.
        api_base: Base URL of the OpenAI-compatible endpoint.

    Returns:
        The assistant message text, or ``""`` if the server returned no content.
    """
    from openai import OpenAI

    # The vLLM server does not check the key unless started with --api-key.
    client = OpenAI(base_url=api_base, api_key="dummy")

    # Encode image to base64 (PNG, RGB) for an inline data URL.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    img_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

    if task == "classify":
        user_text = "What type of document is shown in this image? Respond with structured JSON."
    elif task == "extract":
        user_text = "Extract all relevant information from this document as a structured JSON."
    else:
        user_text = "First classify this document, then extract all information from it as structured JSON."

    response = client.chat.completions.create(
        model="Jwalit/gemma4-e4b-kyc-document-extractor",
        messages=[
            {
                "role": "system",
                "content": (
                    "You are an expert KYC document analyst. You can classify and extract information "
                    "from Indian identity documents including Aadhaar cards, PAN cards, Passports, Visas, "
                    "and Election Cards (Voter IDs). Always respond with accurate, structured JSON output."
                ),
            },
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}},
                    {"type": "text", "text": user_text},
                ],
            },
        ],
        max_tokens=1024,
        # Same sampling settings as the local vLLM / Transformers backends.
        temperature=0.1,
        top_p=0.95,
    )

    # ``message.content`` is Optional in the chat completion schema; normalise
    # to str so callers can always treat the result as text.
    return response.choices[0].message.content or ""


if __name__ == "__main__":
    import os

    parser = argparse.ArgumentParser(description="KYC Document Extraction Inference")
    parser.add_argument("--image", type=str, required=True, help="Path to document image")
    parser.add_argument("--task", choices=["classify", "extract", "combined"], default="combined")
    parser.add_argument("--backend", choices=["vllm", "transformers", "api"], default="transformers")
    parser.add_argument("--api-base", type=str, default="http://localhost:8000/v1")
    args = parser.parse_args()

    # Validate the path up front. Otherwise a typo only surfaces after the
    # chosen backend has imported its dependencies (and possibly after it has
    # loaded model weights).
    if not os.path.isfile(args.image):
        parser.error(f"image file not found: {args.image}")

    print("\n🔍 KYC Document Analysis")
    print(f"  Image: {args.image}")
    print(f"  Task: {args.task}")
    print(f"  Backend: {args.backend}\n")

    if args.backend == "vllm":
        result = inference_with_vllm_offline(args.image, args.task)
    elif args.backend == "api":
        result = inference_with_openai_client(args.image, args.task, args.api_base)
    else:
        result = inference_with_transformers(args.image, args.task)

    print("📄 Result:")
    # Chat models often wrap JSON in a Markdown ```json ... ``` fence, so strip
    # it before parsing. Output that still is not valid JSON is printed verbatim.
    text = (result or "").strip()
    if text.startswith("```"):
        text = text.partition("\n")[2].rstrip()
        if text.endswith("```"):
            text = text[:-3]
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        print(result)
    else:
        # ensure_ascii=False keeps non-ASCII text (e.g. Devanagari on Indian
        # ID documents) readable instead of emitting \uXXXX escapes.
        print(json.dumps(parsed, indent=2, ensure_ascii=False))