---
license: apache-2.0
base_model: google/gemma-4-E4B-it
tags:
- sft
- trl
- peft
- qlora
- kyc
- document-extraction
- document-classification
- aadhaar
- pan-card
- passport
- visa
- election-card
- gemma4
- vision-language-model
- vllm
datasets:
- Jwalit/kyc-document-extraction-vlm
pipeline_tag: image-text-to-text
library_name: transformers
---

# Gemma 4 E4B — KYC Document Extractor & Classifier

**Production-ready Vision-Language Model for Indian KYC Document Extraction and Classification**

Fine-tuned from [`google/gemma-4-E4B-it`](https://huggingface.co/google/gemma-4-E4B-it) using QLoRA SFT on a synthetic KYC document dataset covering 5 Indian identity document types.

## 🎯 Capabilities

| Task | Description |
|------|-------------|
| **Document Classification** | Classify a document as: Aadhaar Card, PAN Card, Passport, Visa, or Election Card (Voter ID) |
| **Field Extraction** | Extract all structured fields (name, DOB, ID number, address, etc.) as JSON |
| **Combined** | Classify and extract in a single pass |

## 📋 Supported Document Types

| Document | Fields Extracted |
|----------|-----------------|
| **Aadhaar Card** | full_name, date_of_birth, gender, father_name, aadhaar_number, address, VID |
| **PAN Card** | full_name, father_name, date_of_birth, pan_number |
| **Passport** | surname, given_name, nationality, gender, date_of_birth, passport_number, place_of_birth, date_of_issue, date_of_expiry, place_of_issue |
| **Visa** | issuing_country, visa_type, visa_category, visa_number, full_name, nationality, gender, date_of_birth, passport_number, date_of_issue, date_of_expiry, entries |
| **Election Card** | voter_id, full_name, relative_name, gender, date_of_birth, age, state, constituency, address |

## 🚀 Quick Start

### With Transformers

```python
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image

model_id = "Jwalit/gemma4-e4b-kyc-document-extractor"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

image = Image.open("document.jpg").convert("RGB")

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are an expert KYC document analyst. Always respond with accurate, structured JSON output."}],
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": "Classify this document and extract all information as structured JSON."},
        ],
    },
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

with torch.no_grad():
    # Low temperature keeps the JSON output near-deterministic;
    # do_sample=True is required for temperature to take effect.
    output = model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.1)

# Decode only the newly generated tokens, skipping the prompt.
result = processor.batch_decode(
    output[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
)[0]
print(result)
```
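The model is trained to answer with JSON, but generations can occasionally wrap the object in a markdown fence or stray prose. Below is a minimal, best-effort post-processing sketch; the `parse_kyc_json` helper is illustrative and not part of this repo:

```python
import json
import re

def parse_kyc_json(generated: str) -> dict:
    """Best-effort extraction of the first JSON object from model output.

    Illustrative helper only; adapt to your own error-handling policy.
    """
    # Strip a ```json ... ``` fence if the model added one.
    text = re.sub(r"^```(?:json)?\s*|\s*```$", "", generated.strip())
    # Fall back to the first {...} span if prose surrounds the JSON.
    match = re.search(r"\{.*\}", text, flags=re.DOTALL)
    if match is None:
        raise ValueError("No JSON object found in model output")
    return json.loads(match.group(0))

record = parse_kyc_json(result)  # `result` from the snippet above
print(record.get("document_type"))
```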
### With vLLM (Production Deployment)

```bash
# Start an OpenAI-compatible server
python -m vllm.entrypoints.openai.api_server \
    --model Jwalit/gemma4-e4b-kyc-document-extractor \
    --trust-remote-code \
    --max-model-len 4096 \
    --dtype bfloat16 \
    --gpu-memory-utilization 0.9
```

```python
import base64

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

with open("document.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode()

response = client.chat.completions.create(
    model="Jwalit/gemma4-e4b-kyc-document-extractor",
    messages=[
        {
            "role": "system",
            "content": "You are an expert KYC document analyst. Always respond with accurate, structured JSON output.",
        },
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}},
                {"type": "text", "text": "Classify and extract all fields from this KYC document as JSON."},
            ],
        },
    ],
    max_tokens=1024,
    temperature=0.1,
)
print(response.choices[0].message.content)
```

### With vLLM Offline (Batch Processing)

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="Jwalit/gemma4-e4b-kyc-document-extractor",
    trust_remote_code=True,
    max_model_len=4096,
    dtype="bfloat16",
)

sampling_params = SamplingParams(temperature=0.1, max_tokens=1024)

# For batch processing, pass OpenAI-style image messages to llm.chat(),
# e.g. llm.chat(messages, sampling_params) with one conversation per document.
```

## 🏋️ Training Details

### Method

- **Base Model**: `google/gemma-4-E4B-it` (~8B params, `Gemma4ForConditionalGeneration`)
- **Fine-tuning**: QLoRA SFT (4-bit NF4 quantization + LoRA rank 16 on the text decoder)
- **Vision Encoder**: Frozen SigLIP (280 tokens per image, 768-dim, 16 layers)
- **Framework**: TRL `SFTTrainer` + PEFT + bitsandbytes

### Hyperparameters

| Parameter | Value |
|-----------|-------|
| Learning Rate | 2e-4 |
| Epochs | 3 |
| Batch Size | 2 × 8 (gradient accumulation) = 16 effective |
| LoRA Rank (r) | 16 |
| LoRA Alpha | 32 |
| LoRA Dropout | 0.05 |
| Optimizer | AdamW (fused) |
| LR Scheduler | Cosine with 5% warmup |
| Precision | bf16 |
| Gradient Checkpointing | ✅ |
| Target Modules | q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj |

### Dataset

- **Dataset**: [`Jwalit/kyc-document-extraction-vlm`](https://huggingface.co/datasets/Jwalit/kyc-document-extraction-vlm)
- **Size**: 2,704 train / 296 eval samples
- **Document Types**: 5 (Aadhaar, PAN, Passport, Visa, Election Card)
- **Task Types**: Classification, Extraction, Combined (balanced across all)
- **Format**: Conversational VLM (messages with `{"type": "image"}` + `{"type": "text"}` content parts; see the sketch below)
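To make the conversational format concrete, here is an illustrative single training record, assuming the message structure described in the bullets above; this is inferred from the format notes, so consult the dataset card for the exact schema:

```python
# Illustrative record shape inferred from the format bullets above;
# the actual dataset schema may differ in detail.
example_record = {
    "messages": [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are an expert KYC document analyst. Always respond with accurate, structured JSON output."}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image"},  # paired with the sample's image column
                {"type": "text", "text": "Classify this document and extract all information as structured JSON."},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": '{"document_type": "pan_card", "full_name": "..."}'}],
        },
    ]
}
```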
### Architecture

```
Gemma4ForConditionalGeneration
├── Vision Encoder (SigLIP, FROZEN)
│   ├── 16 layers, 768-dim, 12 attention heads
│   ├── Patch size: 16, Pooling kernel: 3
│   └── Output: 280 soft tokens per image
├── Text Decoder (LoRA applied here)
│   ├── 42 layers (36 sliding + 6 full attention)
│   ├── 2560 hidden, 8 heads, GQA
│   ├── 262K vocab, 131K context
│   └── LoRA on: q/k/v/o_proj + gate/up/down_proj
└── Audio Encoder (unused, frozen)
```

## 🔧 Reproduce Training

```bash
# Install dependencies
pip install torch transformers trl datasets peft accelerate bitsandbytes trackio flash-attn pillow

# Run training (requires a GPU with ≥24 GB VRAM; recommended: A100 80GB)
python train_kyc_vlm.py
```

Or via the TRL CLI:

```bash
trl sft \
    --model_name_or_path google/gemma-4-E4B-it \
    --dataset_name Jwalit/kyc-document-extraction-vlm \
    --output_dir ./gemma4-kyc-extractor \
    --learning_rate 2e-4 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --bf16 \
    --gradient_checkpointing \
    --push_to_hub \
    --hub_model_id Jwalit/gemma4-e4b-kyc-document-extractor
```

## ⚡ Performance & Deployment Notes

- **vLLM compatible**: Native support via the `Gemma4ForConditionalGeneration` architecture
- **280 image tokens**: Efficient — processes document images in ~280 tokens (vs 1024+ for many other VLMs)
- **128K context**: Can handle multiple document pages in a single request
- **QLoRA deployment**: Merge adapters for full-speed inference, or serve with PEFT for memory efficiency

### Merging Adapters (for production — recommended before vLLM serving)

```python
import torch
from peft import PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor

# Load the base VLM, then attach the LoRA adapters on top of it.
base_model = AutoModelForImageTextToText.from_pretrained(
    "google/gemma-4-E4B-it",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model = PeftModel.from_pretrained(base_model, "Jwalit/gemma4-e4b-kyc-document-extractor")

# Fold the adapters into the base weights for full-speed inference.
merged_model = model.merge_and_unload()
merged_model.save_pretrained("./merged-kyc-extractor")

# Save the processor alongside, then push the merged model for faster vLLM serving.
AutoProcessor.from_pretrained("google/gemma-4-E4B-it").save_pretrained("./merged-kyc-extractor")
```

## 📊 Expected Output Format

```json
{
  "document_type": "aadhaar_card",
  "full_name": "Rajesh Kumar Singh",
  "date_of_birth": "15/03/1985",
  "gender": "Male",
  "father_name": "Suresh Kumar Singh",
  "aadhaar_number": "1234 5678 9012",
  "address": "123, MG Road, Mumbai, Maharashtra - 400001",
  "vid": "1234 5678 9012 3456"
}
```

## ⚠️ Limitations

- Trained on **synthetic** KYC documents — accuracy on real-world documents will improve with fine-tuning on real (anonymized) KYC samples
- Best results when further fine-tuned with 200-500 real document images per type
- Vision encoder is frozen — it cannot learn new visual features beyond the base SigLIP capabilities
- Indian documents only (Aadhaar, PAN, Passport, Visa, Election Card)

## 📝 License

Apache 2.0 (same as the base model, `google/gemma-4-E4B-it`)
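## 🧪 Appendix: Validating Extracted Output

As a complement to the expected output format above, here is an illustrative validation sketch (not shipped with this model) that checks a parsed record against the field lists in the Supported Document Types table; extend the `REQUIRED_FIELDS` map for the remaining document types:

```python
# Illustrative sketch: field names are taken from this card's tables;
# the helper itself is hypothetical and not part of this repo.
REQUIRED_FIELDS = {
    "aadhaar_card": [
        "full_name", "date_of_birth", "gender", "father_name",
        "aadhaar_number", "address",
    ],
    "pan_card": ["full_name", "father_name", "date_of_birth", "pan_number"],
    # Extend with passport, visa, and election_card as needed.
}

def missing_fields(record: dict) -> list[str]:
    """Return the required fields that are absent or empty in `record`."""
    expected = REQUIRED_FIELDS.get(record.get("document_type"), [])
    return [f for f in expected if not record.get(f)]

issues = missing_fields({"document_type": "pan_card", "full_name": "Rajesh Kumar Singh"})
print(issues)  # ['father_name', 'date_of_birth', 'pan_number']
```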