Jwalit commited on
Commit
7357eb4
·
verified ·
1 Parent(s): bf014b7

Add Colab training notebook for free GPU training

Browse files
Files changed (1) hide show
  1. train_kyc_colab.ipynb +362 -0
train_kyc_colab.ipynb ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 🔍 Train Gemma 4 E4B for KYC Document Extraction & Classification\n",
8
+ "\n",
9
+ "**Free GPU Training on Google Colab**\n",
10
+ "\n",
11
+ "This notebook fine-tunes `google/gemma-4-E4B-it` using QLoRA SFT for:\n",
12
+ "- **Document Classification**: Aadhaar, PAN, Passport, Visa, Election Card\n",
13
+ "- **Field Extraction**: Extract all structured fields as JSON\n",
14
+ "\n",
15
+ "**Requirements**: Colab T4 (free) or L4/A100 (Colab Pro)\n",
16
+ "\n",
17
+ "| Resource | Link |\n",
18
+ "|----------|------|\n",
19
+ "| Dataset | [Jwalit/kyc-document-extraction-vlm](https://huggingface.co/datasets/Jwalit/kyc-document-extraction-vlm) |\n",
20
+ "| Model Repo | [Jwalit/gemma4-e4b-kyc-document-extractor](https://huggingface.co/Jwalit/gemma4-e4b-kyc-document-extractor) |\n",
21
+ "| Base Model | [google/gemma-4-E4B-it](https://huggingface.co/google/gemma-4-E4B-it) |"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "markdown",
26
+ "metadata": {},
27
+ "source": [
28
+ "## 1. Install Dependencies"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": null,
34
+ "metadata": {},
35
+ "outputs": [],
36
+ "source": [
37
+ "!pip install -q torch transformers trl datasets peft accelerate bitsandbytes trackio pillow\n",
38
+ "!pip install -q flash-attn --no-build-isolation"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "markdown",
43
+ "metadata": {},
44
+ "source": [
45
+ "## 2. Login to Hugging Face"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": null,
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "from huggingface_hub import notebook_login\n",
55
+ "notebook_login()"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "markdown",
60
+ "metadata": {},
61
+ "source": [
62
+ "## 3. Check GPU"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "import torch\n",
72
+ "print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
73
+ "print(f\"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB\")\n",
74
+ "print(f\"CUDA: {torch.version.cuda}\")\n",
75
+ "print(f\"PyTorch: {torch.__version__}\")"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "markdown",
80
+ "metadata": {},
81
+ "source": [
82
+ "## 4. Load Dataset"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "metadata": {},
89
+ "outputs": [],
90
+ "source": [
91
+ "from datasets import load_dataset\n",
92
+ "\n",
93
+ "DATASET_ID = \"Jwalit/kyc-document-extraction-vlm\"\n",
94
+ "dataset = load_dataset(DATASET_ID)\n",
95
+ "train_dataset = dataset[\"train\"]\n",
96
+ "eval_dataset = dataset[\"test\"]\n",
97
+ "\n",
98
+ "print(f\"Train: {len(train_dataset)} samples\")\n",
99
+ "print(f\"Eval: {len(eval_dataset)} samples\")\n",
100
+ "print(f\"Columns: {train_dataset.column_names}\")\n",
101
+ "\n",
102
+ "# Preview a sample\n",
103
+ "sample = train_dataset[0]\n",
104
+ "print(f\"\\nSample message roles: {[m['role'] for m in sample['messages']]}\")\n",
105
+ "print(f\"Num images: {len(sample['images'])}\")"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "markdown",
110
+ "metadata": {},
111
+ "source": [
112
+ "## 5. Load Model with QLoRA (4-bit)"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "import torch\n",
122
+ "from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig\n",
123
+ "\n",
124
+ "MODEL_ID = \"google/gemma-4-E4B-it\"\n",
125
+ "\n",
126
+ "# 4-bit quantization\n",
127
+ "bnb_config = BitsAndBytesConfig(\n",
128
+ " load_in_4bit=True,\n",
129
+ " bnb_4bit_use_double_quant=True,\n",
130
+ " bnb_4bit_quant_type=\"nf4\",\n",
131
+ " bnb_4bit_compute_dtype=torch.bfloat16,\n",
132
+ ")\n",
133
+ "\n",
134
+ "print(f\"Loading {MODEL_ID}...\")\n",
135
+ "model = AutoModelForImageTextToText.from_pretrained(\n",
136
+ " MODEL_ID,\n",
137
+ " device_map=\"auto\",\n",
138
+ " torch_dtype=torch.bfloat16,\n",
139
+ " quantization_config=bnb_config,\n",
140
+ " attn_implementation=\"flash_attention_2\",\n",
141
+ ")\n",
142
+ "\n",
143
+ "processor = AutoProcessor.from_pretrained(MODEL_ID)\n",
144
+ "if processor.tokenizer.pad_token is None:\n",
145
+ " processor.tokenizer.pad_token = processor.tokenizer.eos_token\n",
146
+ "\n",
147
+ "print(f\"Model loaded: {model.__class__.__name__}\")\n",
148
+ "print(f\"GPU memory used: {torch.cuda.memory_allocated() / 1e9:.2f} GB\")"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "markdown",
153
+ "metadata": {},
154
+ "source": [
155
+ "## 6. Configure LoRA & Training"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "metadata": {},
162
+ "outputs": [],
163
+ "source": [
164
+ "import os\n",
165
+ "from peft import LoraConfig\n",
166
+ "from trl import SFTConfig, SFTTrainer\n",
167
+ "\n",
168
+ "# ===== YOUR SETTINGS =====\n",
169
+ "HUB_MODEL_ID = \"Jwalit/gemma4-e4b-kyc-document-extractor\" # Change to your username!\n",
170
+ "OUTPUT_DIR = \"./gemma4-kyc-extractor\"\n",
171
+ "\n",
172
+ "# Trackio monitoring (optional)\n",
173
+ "os.environ[\"TRACKIO_SPACE_ID\"] = \"Jwalit/kyc-trackio\" # Change to your space\n",
174
+ "os.environ[\"TRACKIO_PROJECT\"] = \"kyc-document-extractor\"\n",
175
+ "# =========================\n",
176
+ "\n",
177
+ "# LoRA: target text decoder only (vision encoder stays frozen)\n",
178
+ "peft_config = LoraConfig(\n",
179
+ " r=16,\n",
180
+ " lora_alpha=32,\n",
181
+ " lora_dropout=0.05,\n",
182
+ " bias=\"none\",\n",
183
+ " task_type=\"CAUSAL_LM\",\n",
184
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
185
+ ")\n",
186
+ "\n",
187
+ "# SFT config optimized for T4 (16GB VRAM)\n",
188
+ "training_args = SFTConfig(\n",
189
+ " output_dir=OUTPUT_DIR,\n",
190
+ " num_train_epochs=3,\n",
191
+ " per_device_train_batch_size=1, # T4: batch=1, accumulate=16\n",
192
+ " per_device_eval_batch_size=1,\n",
193
+ " gradient_accumulation_steps=16, # Effective batch = 16\n",
194
+ " learning_rate=2e-4,\n",
195
+ " lr_scheduler_type=\"cosine\",\n",
196
+ " warmup_ratio=0.05,\n",
197
+ " bf16=True,\n",
198
+ " optim=\"adamw_torch_fused\",\n",
199
+ " gradient_checkpointing=True,\n",
200
+ " max_length=None, # CRITICAL for VLMs\n",
201
+ " logging_strategy=\"steps\",\n",
202
+ " logging_steps=10,\n",
203
+ " logging_first_step=True,\n",
204
+ " disable_tqdm=False, # Keep tqdm in Colab\n",
205
+ " report_to=\"trackio\",\n",
206
+ " run_name=\"gemma4-kyc-colab\",\n",
207
+ " eval_strategy=\"steps\",\n",
208
+ " eval_steps=100,\n",
209
+ " save_strategy=\"steps\",\n",
210
+ " save_steps=200,\n",
211
+ " save_total_limit=2,\n",
212
+ " load_best_model_at_end=True,\n",
213
+ " metric_for_best_model=\"eval_loss\",\n",
214
+ " push_to_hub=True,\n",
215
+ " hub_model_id=HUB_MODEL_ID,\n",
216
+ " hub_strategy=\"every_save\",\n",
217
+ " assistant_only_loss=True,\n",
218
+ ")\n",
219
+ "\n",
220
+ "print(\"Config ready!\")\n",
221
+ "print(f\" Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}\")\n",
222
+ "print(f\" Push to: {HUB_MODEL_ID}\")"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "markdown",
227
+ "metadata": {},
228
+ "source": [
229
+ "## 7. Train! 🚀"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "metadata": {},
236
+ "outputs": [],
237
+ "source": [
238
+ "trainer = SFTTrainer(\n",
239
+ " model=model,\n",
240
+ " args=training_args,\n",
241
+ " train_dataset=train_dataset,\n",
242
+ " eval_dataset=eval_dataset,\n",
243
+ " peft_config=peft_config,\n",
244
+ " processing_class=processor,\n",
245
+ ")\n",
246
+ "\n",
247
+ "# Print trainable params\n",
248
+ "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
249
+ "total = sum(p.numel() for p in model.parameters())\n",
250
+ "print(f\"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)\")\n",
251
+ "\n",
252
+ "# Train\n",
253
+ "train_result = trainer.train()\n",
254
+ "\n",
255
+ "# Save & push\n",
256
+ "trainer.save_model(OUTPUT_DIR)\n",
257
+ "trainer.push_to_hub()\n",
258
+ "\n",
259
+ "print(f\"\\n✅ Done! Model at: https://huggingface.co/{HUB_MODEL_ID}\")"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "markdown",
264
+ "metadata": {},
265
+ "source": [
266
+ "## 8. Test Inference"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": null,
272
+ "metadata": {},
273
+ "outputs": [],
274
+ "source": [
275
+ "# Test on a sample from the eval set\n",
276
+ "test_sample = eval_dataset[0]\n",
277
+ "test_image = test_sample[\"images\"][0]\n",
278
+ "\n",
279
+ "# Display the document\n",
280
+ "from IPython.display import display\n",
281
+ "display(test_image)\n",
282
+ "\n",
283
+ "# Run inference\n",
284
+ "messages = [\n",
285
+ " {\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\": \"You are an expert KYC document analyst. Always respond with accurate, structured JSON output.\"}]},\n",
286
+ " {\"role\": \"user\", \"content\": [\n",
287
+ " {\"type\": \"image\"},\n",
288
+ " {\"type\": \"text\", \"text\": \"Classify this document and extract all information as structured JSON.\"}\n",
289
+ " ]}\n",
290
+ "]\n",
291
+ "\n",
292
+ "inputs = processor.apply_chat_template(\n",
293
+ " messages, add_generation_prompt=True, tokenize=True,\n",
294
+ " return_dict=True, return_tensors=\"pt\", images=[test_image]\n",
295
+ ").to(model.device)\n",
296
+ "\n",
297
+ "with torch.no_grad():\n",
298
+ " output = model.generate(**inputs, max_new_tokens=1024, temperature=0.1, do_sample=True)\n",
299
+ "\n",
300
+ "result = processor.batch_decode(output[:, inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)[0]\n",
301
+ "\n",
302
+ "import json\n",
303
+ "print(\"\\n📄 Model Output:\")\n",
304
+ "try:\n",
305
+ " print(json.dumps(json.loads(result), indent=2))\n",
306
+ "except:\n",
307
+ " print(result)\n",
308
+ "\n",
309
+ "print(\"\\n📋 Ground Truth:\")\n",
310
+ "gt_msg = test_sample[\"messages\"][-1] # assistant message\n",
311
+ "gt_text = gt_msg[\"content\"][0][\"text\"] if isinstance(gt_msg[\"content\"], list) else gt_msg[\"content\"]\n",
312
+ "try:\n",
313
+ " print(json.dumps(json.loads(gt_text), indent=2))\n",
314
+ "except:\n",
315
+ " print(gt_text)"
316
+ ]
317
+ },
318
+ {
319
+ "cell_type": "markdown",
320
+ "metadata": {},
321
+ "source": [
322
+ "## 9. Deploy with vLLM\n",
323
+ "\n",
324
+ "After training, deploy the model with vLLM for production speed:\n",
325
+ "\n",
326
+ "```bash\n",
327
+ "# Merge LoRA adapters first (optional but recommended)\n",
328
+ "python -c \"\n",
329
+ "from peft import AutoPeftModelForCausalLM\n",
330
+ "import torch\n",
331
+ "model = AutoPeftModelForCausalLM.from_pretrained('Jwalit/gemma4-e4b-kyc-document-extractor', device_map='auto', torch_dtype=torch.bfloat16)\n",
332
+ "merged = model.merge_and_unload()\n",
333
+ "merged.save_pretrained('./merged-kyc-extractor')\n",
334
+ "\"\n",
335
+ "\n",
336
+ "# Start vLLM server\n",
337
+ "python -m vllm.entrypoints.openai.api_server \\\n",
338
+ " --model ./merged-kyc-extractor \\\n",
339
+ " --max-model-len 4096 \\\n",
340
+ " --dtype bfloat16\n",
341
+ "```"
342
+ ]
343
+ }
344
+ ],
345
+ "metadata": {
346
+ "accelerator": "GPU",
347
+ "colab": {
348
+ "gpuType": "T4",
349
+ "provenance": []
350
+ },
351
+ "kernelspec": {
352
+ "display_name": "Python 3",
353
+ "name": "python3"
354
+ },
355
+ "language_info": {
356
+ "name": "python",
357
+ "version": "3.10.12"
358
+ }
359
+ },
360
+ "nbformat": 4,
361
+ "nbformat_minor": 0
362
+ }