Jwalit commited on
Commit
200e006
·
verified ·
1 Parent(s): 5138c52

Add training script

Browse files
Files changed (1) hide show
  1. train_kyc_vlm.py +237 -0
train_kyc_vlm.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fine-tune Google Gemma 4 E4B-IT for KYC Document Extraction & Classification.
3
+
4
+ Model: google/gemma-4-E4B-it (Gemma4ForConditionalGeneration, ~8B params)
5
+ Method: QLoRA SFT (4-bit quantization + LoRA on text decoder)
6
+ Dataset: Jwalit/kyc-document-extraction-vlm (synthetic KYC documents)
7
+ Hardware: A100-large (80GB VRAM)
8
+ Output: Jwalit/gemma4-e4b-kyc-document-extractor
9
+
10
+ Reference implementation: TRL SFT VLM docs (https://huggingface.co/docs/trl/sft_trainer#training-vision-language-models)
11
+ """
12
+
13
+ import os
14
+ import torch
15
+ from datasets import load_dataset
16
+ from transformers import (
17
+ AutoProcessor,
18
+ AutoModelForImageTextToText,
19
+ BitsAndBytesConfig,
20
+ )
21
+ from peft import LoraConfig
22
+ from trl import SFTConfig, SFTTrainer
23
+
24
+ # ============================================================
25
+ # Configuration
26
+ # ============================================================
27
+
28
+ MODEL_ID = "google/gemma-4-E4B-it"
29
+ DATASET_ID = "Jwalit/kyc-document-extraction-vlm"
30
+ OUTPUT_DIR = "./gemma4-e4b-kyc-extractor"
31
+ HUB_MODEL_ID = "Jwalit/gemma4-e4b-kyc-document-extractor"
32
+
33
+ # Training hyperparameters (based on VLM SFT best practices)
34
+ LEARNING_RATE = 2e-4 # Higher LR for LoRA adapters
35
+ NUM_EPOCHS = 3
36
+ BATCH_SIZE = 2
37
+ GRADIENT_ACCUMULATION = 8 # Effective batch size = 2 * 8 = 16
38
+ MAX_SEQ_LENGTH = None # CRITICAL for VLMs: don't truncate image tokens
39
+
40
+ # LoRA config (target text decoder only, vision encoder stays frozen)
41
+ LORA_R = 16
42
+ LORA_ALPHA = 32
43
+ LORA_DROPOUT = 0.05
44
+
45
+ # ============================================================
46
+ # Setup Trackio monitoring via environment variables
47
+ # ============================================================
48
+
49
+ os.environ["TRACKIO_SPACE_ID"] = "Jwalit/kyc-trackio"
50
+ os.environ["TRACKIO_PROJECT"] = "kyc-document-extractor"
51
+
52
+ # ============================================================
53
+ # Load dataset
54
+ # ============================================================
55
+
56
+ print("Loading dataset...")
57
+ dataset = load_dataset(DATASET_ID)
58
+ train_dataset = dataset["train"]
59
+ eval_dataset = dataset["test"]
60
+
61
+ print(f"Train: {len(train_dataset)} samples")
62
+ print(f"Eval: {len(eval_dataset)} samples")
63
+ print(f"Sample keys: {train_dataset.column_names}")
64
+
65
+ # ============================================================
66
+ # Model & Processor setup
67
+ # ============================================================
68
+
69
+ print(f"\nLoading model: {MODEL_ID}")
70
+
71
+ # 4-bit quantization for memory efficiency
72
+ bnb_config = BitsAndBytesConfig(
73
+ load_in_4bit=True,
74
+ bnb_4bit_use_double_quant=True,
75
+ bnb_4bit_quant_type="nf4",
76
+ bnb_4bit_compute_dtype=torch.bfloat16,
77
+ )
78
+
79
+ # Load model
80
+ model = AutoModelForImageTextToText.from_pretrained(
81
+ MODEL_ID,
82
+ device_map="auto",
83
+ torch_dtype=torch.bfloat16,
84
+ quantization_config=bnb_config,
85
+ attn_implementation="flash_attention_2",
86
+ )
87
+
88
+ # Load processor
89
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
90
+
91
+ # Ensure pad token is set
92
+ if processor.tokenizer.pad_token is None:
93
+ processor.tokenizer.pad_token = processor.tokenizer.eos_token
94
+
95
+ print(f"Model loaded: {model.__class__.__name__}")
96
+ print(f"Model device map: {model.hf_device_map if hasattr(model, 'hf_device_map') else 'N/A'}")
97
+
98
+ # ============================================================
99
+ # LoRA Configuration
100
+ # ============================================================
101
+
102
+ # Target only the text decoder layers (vision encoder stays frozen)
103
+ peft_config = LoraConfig(
104
+ r=LORA_R,
105
+ lora_alpha=LORA_ALPHA,
106
+ lora_dropout=LORA_DROPOUT,
107
+ bias="none",
108
+ task_type="CAUSAL_LM",
109
+ target_modules=[
110
+ "q_proj", "k_proj", "v_proj", "o_proj",
111
+ "gate_proj", "up_proj", "down_proj",
112
+ ],
113
+ )
114
+
115
+ print(f"\nLoRA config: r={LORA_R}, alpha={LORA_ALPHA}, dropout={LORA_DROPOUT}")
116
+ print(f"Target modules: {peft_config.target_modules}")
117
+
118
+ # ============================================================
119
+ # SFT Training Configuration
120
+ # ============================================================
121
+
122
+ training_args = SFTConfig(
123
+ output_dir=OUTPUT_DIR,
124
+
125
+ # Training schedule
126
+ num_train_epochs=NUM_EPOCHS,
127
+ per_device_train_batch_size=BATCH_SIZE,
128
+ per_device_eval_batch_size=1,
129
+ gradient_accumulation_steps=GRADIENT_ACCUMULATION,
130
+
131
+ # Learning rate
132
+ learning_rate=LEARNING_RATE,
133
+ lr_scheduler_type="cosine",
134
+ warmup_ratio=0.05,
135
+
136
+ # Precision & optimization
137
+ bf16=True,
138
+ optim="adamw_torch_fused",
139
+ gradient_checkpointing=True,
140
+
141
+ # VLM-specific: DO NOT truncate (image tokens get cut off)
142
+ max_length=None,
143
+
144
+ # Logging - plain text, no tqdm
145
+ logging_strategy="steps",
146
+ logging_steps=10,
147
+ logging_first_step=True,
148
+ disable_tqdm=True,
149
+ report_to="trackio",
150
+ run_name="gemma4-e4b-kyc-sft-qlora",
151
+
152
+ # Eval
153
+ eval_strategy="steps",
154
+ eval_steps=100,
155
+
156
+ # Saving
157
+ save_strategy="steps",
158
+ save_steps=200,
159
+ save_total_limit=3,
160
+ load_best_model_at_end=True,
161
+ metric_for_best_model="eval_loss",
162
+
163
+ # Hub push
164
+ push_to_hub=True,
165
+ hub_model_id=HUB_MODEL_ID,
166
+ hub_strategy="every_save",
167
+
168
+ # SFT-specific
169
+ assistant_only_loss=True, # Only train on assistant responses
170
+ )
171
+
172
+ # ============================================================
173
+ # Create Trainer
174
+ # ============================================================
175
+
176
+ print("\nInitializing SFTTrainer...")
177
+
178
+ trainer = SFTTrainer(
179
+ model=model,
180
+ args=training_args,
181
+ train_dataset=train_dataset,
182
+ eval_dataset=eval_dataset,
183
+ peft_config=peft_config,
184
+ processing_class=processor, # Use processor (not tokenizer) for VLMs
185
+ )
186
+
187
+ # Print trainable parameters
188
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
189
+ total_params = sum(p.numel() for p in model.parameters())
190
+ print(f"\nTrainable params: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)")
191
+
192
+ # ============================================================
193
+ # Train!
194
+ # ============================================================
195
+
196
+ print("\n" + "="*60)
197
+ print("Starting training...")
198
+ print(f" Model: {MODEL_ID}")
199
+ print(f" Dataset: {DATASET_ID}")
200
+ print(f" Epochs: {NUM_EPOCHS}")
201
+ print(f" Batch size: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} = {BATCH_SIZE * GRADIENT_ACCUMULATION}")
202
+ print(f" Learning rate: {LEARNING_RATE}")
203
+ print(f" LoRA rank: {LORA_R}")
204
+ print(f" Push to: {HUB_MODEL_ID}")
205
+ print("="*60 + "\n")
206
+
207
+ train_result = trainer.train()
208
+
209
+ # ============================================================
210
+ # Save & push final model
211
+ # ============================================================
212
+
213
+ print("\nSaving final model...")
214
+ trainer.save_model(OUTPUT_DIR)
215
+ trainer.push_to_hub()
216
+
217
+ # Log final metrics
218
+ metrics = train_result.metrics
219
+ print("\n" + "="*60)
220
+ print("Training completed!")
221
+ print(f" Final train loss: {metrics.get('train_loss', 'N/A')}")
222
+ print(f" Total steps: {metrics.get('total_flos', 'N/A')}")
223
+ print(f" Model saved to: {HUB_MODEL_ID}")
224
+ print(f" View at: https://huggingface.co/{HUB_MODEL_ID}")
225
+ print("="*60)
226
+
227
+ # ============================================================
228
+ # Run quick evaluation
229
+ # ============================================================
230
+
231
+ print("\nRunning final evaluation...")
232
+ eval_metrics = trainer.evaluate()
233
+ print(f" Eval loss: {eval_metrics.get('eval_loss', 'N/A')}")
234
+ print(f" Eval runtime: {eval_metrics.get('eval_runtime', 'N/A')}s")
235
+
236
+ print("\n✅ Training complete! Model is ready for vLLM deployment.")
237
+ print(f"🔗 https://huggingface.co/{HUB_MODEL_ID}")