YUNGHUI2024 committed on
Commit
05f3cf9
·
verified ·
1 Parent(s): e23264b

Add training pipeline v2.0 — DeepSeek-VL2-tiny × ChartQA LoRA

Browse files
Files changed (1) hide show
  1. train_pipeline.py +423 -0
train_pipeline.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ DeepSeek-VL2-tiny Chart Fine-tuning Pipeline v2.0
4
+ ═══════════════════════════════════════════════════
5
+ Dataset : HuggingFaceM4/ChartQA (fallback from YUNGHUI2024/deepseek-ocr2-chart-finetune)
6
+ Model : deepseek-ai/deepseek-vl2-tiny (1B active / 3B total, bf16 ≈6.3 GB)
7
+ Method : LoRA (r=16, target q/k/v/o_proj) + gradient checkpointing
8
+ VRAM : Tested on RTX 3060 12 GB (batch=1, grad_accum=16)
9
+ Tracking : Trackio (optional) — set env vars:
10
+ TRACKIO_SPACE_ID, TRACKIO_PROJECT
11
+ Output : YUNGHUI2024/deepseek-vl2-tiny-chartqa-lora
12
+
13
+ ═══ Local quick start ═════════════════════════════════════════════
14
+ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
15
+ pip install "transformers>=4.40" "datasets>=2.18" peft accelerate trackio huggingface_hub pillow
16
+ git clone https://github.com/deepseek-ai/DeepSeek-VL2 && cd DeepSeek-VL2 && pip install -e . && cd ..
17
+
18
+ # Log in to the HF Hub (needed for pushing)
19
+ huggingface-cli login
20
+
21
+ python train_pipeline.py
22
+ """
23
+
24
+ import os, sys, subprocess, logging, math
25
+ from pathlib import Path
26
+ import torch
27
+ from datasets import load_dataset
28
+ from PIL import Image
29
+ from torch.utils.data import DataLoader, Dataset
30
+ from transformers import AutoModelForCausalLM, get_cosine_schedule_with_warmup
31
+ from torch.optim import AdamW
32
+ from peft import LoraConfig, get_peft_model, TaskType
33
+
34
# Configure logging once for the whole script: time-stamped, level-tagged lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger(__name__)

# ─── optional: auto-install deepseek_vl2 if not found ────────────────────────
# The DeepSeek-VL2 repository installs the package `deepseek_vl2`
# (`deepseek_vl` is the older DeepSeek-VL package and does not provide the
# V2 classes), so that is the module name to import.
try:
    from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
except ImportError:
    import importlib

    log.info("deepseek_vl2 not found — installing from GitHub …")
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q",
         "git+https://github.com/deepseek-ai/DeepSeek-VL2.git"],
        check=True,  # abort the script if pip fails
    )
    # The package was installed while this interpreter was already running;
    # drop the import system's cached directory listings so it can be found.
    importlib.invalidate_caches()
    from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM

# ─── optional Trackio ─────────────────────────────────────────────────────────
# Tracking is switched on when either environment variable is set; trackio is
# imported only in that case so it remains an optional dependency.
_USE_TRACKIO = bool(os.getenv("TRACKIO_SPACE_ID") or os.getenv("TRACKIO_PROJECT"))
if _USE_TRACKIO:
    import trackio
57
+
58
def tlog(metrics: dict):
    """Forward a metrics dict to Trackio; do nothing when tracking is disabled."""
    if not _USE_TRACKIO:
        return
    trackio.log(metrics)
61
+
62
def talert(title: str, text: str, level: str = "INFO"):
    """Send an alert to Trackio (when enabled) and mirror it to the local log.

    Args:
        title: Short alert identifier.
        text: Human-readable alert body.
        level: Severity label passed through to Trackio and shown in the log line.
    """
    # NOTE(review): the original indentation was lost; the log line is assumed
    # to run unconditionally so alerts are visible even without Trackio.
    tracking_enabled = _USE_TRACKIO
    if tracking_enabled:
        trackio.alert(title=title, text=text, level=level)
    log.info(f"[ALERT {level}] {title}: {text}")
66
+
67
+ # ──────────────────────────────────────────────────────────────────────────────
68
+ # ███ CONFIG ████████████████████████████████████████████████████████████████
69
+ # ──────────────────────────────────────────────────────────────────────────────
70
# Model, dataset and output locations.
MODEL_ID = "deepseek-ai/deepseek-vl2-tiny"
DATASET_ID = "HuggingFaceM4/ChartQA"
HUB_MODEL_ID = "YUNGHUI2024/deepseek-vl2-tiny-chartqa-lora"
OUTPUT_DIR = "./output-deepseek-vl2-chartqa"  # local folder

# LoRA hyper-parameters.
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
LORA_TARGETS = ["q_proj", "k_proj", "v_proj", "o_proj"]

# Training hyper-parameters — tuned for 12 GB VRAM (RTX 3060).
LR = 2e-4
NUM_EPOCHS = 2
BATCH_SIZE = 1     # per-GPU
GRAD_ACCUM = 16    # effective batch = 16
LOG_EVERY = 20     # opt-steps
SAVE_STEPS = 200   # opt-steps

# Dataset caps: set MAX_TRAIN to a small int (e.g. 50) for a quick smoke-test;
# None = full dataset. Validation is capped for speed.
MAX_TRAIN = None
MAX_VAL = 200

# Trackio destination and run name (environment variables override the defaults).
TRACKIO_SPACE = os.getenv("TRACKIO_SPACE_ID", "YUNGHUI2024/ml-intern-chartqa")
TRACKIO_PROJ = os.getenv("TRACKIO_PROJECT", "deepseek-vl2-chartqa")
RUN_NAME = f"vl2tiny_lora_r{LORA_R}_lr{LR}"

# ──────────────────────────────────────────────────────────────────────────────
# ███ TRACKIO INIT ██████████████████████████████████████████████████████████
# ──────────────────────────────────────────────────────────────────────────────
if _USE_TRACKIO:
    # Hyper-parameters recorded alongside the run.
    _run_config = {
        "model": MODEL_ID,
        "dataset": DATASET_ID,
        "lora_r": LORA_R,
        "lora_alpha": LORA_ALPHA,
        "lr": LR,
        "epochs": NUM_EPOCHS,
        "batch_size": BATCH_SIZE,
        "grad_accum": GRAD_ACCUM,
    }
    trackio.init(
        project=TRACKIO_PROJ,
        name=RUN_NAME,
        space_id=TRACKIO_SPACE,
        config=_run_config,
    )
    log.info(f"Trackio init — project={TRACKIO_PROJ} run={RUN_NAME}")
114
+
115
+ # ──────────────────────────────────────────────────────────────────────────────
116
+ # ███ PROCESSOR & MODEL █████████████████████████████████████████████████████
117
+ # ──────────────────────────────────────────────────────────────────────────────
118
# Load the processor; its tokenizer is reused later for label masking.
log.info(f"Loading processor from {MODEL_ID} …")
processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(MODEL_ID)
tokenizer = processor.tokenizer
tokenizer.padding_side = "right"

log.info(f"Loading model {MODEL_ID} → bf16 …")
model: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
# The KV cache is not used for training and is disabled before checkpointing.
model.config.use_cache = False

# Gradient checkpointing BEFORE LoRA wrapping (saves ~30–40% VRAM)
if hasattr(model, "gradient_checkpointing_enable"):
    model.gradient_checkpointing_enable()
else:
    # Fall back to the text decoder. DeepSeek-VL2 exposes it as `language`;
    # `language_model` is kept as a second candidate for other wrappers.
    for _attr in ("language", "language_model"):
        _decoder = getattr(model, _attr, None)
        if _decoder is not None and hasattr(_decoder, "gradient_checkpointing_enable"):
            _decoder.gradient_checkpointing_enable()
            break

# ──────────────────────────────────────────────────────────────────────────────
# ███ LoRA ██████████████████████████████████████████████████████████████████
# ──────────────────────────────────────────────────────────────────────────────
lora_cfg = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=LORA_TARGETS,
    bias="none",
)
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()

# Same statistic again, routed through the logger so it lands in the log output.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
log.info(f"LoRA trainable: {trainable/1e6:.2f}M / {total/1e6:.0f}M "
         f"({100*trainable/total:.2f}%)")

# Move to GPU (falls back to CPU when CUDA is unavailable)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
log.info(f"Model on {device}")
if torch.cuda.is_available():
    mem = torch.cuda.memory_reserved() / 1e9
    log.info(f"VRAM reserved after model load: {mem:.1f} GB")
163
+
164
+ # ──────────────────────────────────────────────────────────────────────────────
165
+ # ███ DATASET ███████████████████████████████████████████████████████████████
166
+ # ──────────────────────────────────────────────────────────────────────────────
167
# Load the dataset and take its train / val splits.
log.info(f"Loading {DATASET_ID} …")
raw = load_dataset(DATASET_ID)
train_raw = raw["train"]
val_raw = raw["val"]

# Optional caps. Both are clamped to the split size so an oversized cap cannot
# index past the end, and a cap of None leaves the split untouched.
if MAX_TRAIN:
    train_raw = train_raw.select(range(min(len(train_raw), MAX_TRAIN)))
if MAX_VAL is not None:
    val_raw = val_raw.select(range(min(len(val_raw), MAX_VAL)))
log.info(f"Train: {len(train_raw):,} Val: {len(val_raw):,}")
tlog({"dataset/train_samples": len(train_raw), "dataset/val_samples": len(val_raw)})
177
+
178
+
179
class ChartQADataset(Dataset):
    """Wrap a Hugging Face split and yield one chat-style sample per row.

    Each row is read through its "image", "query" and "label" fields.
    """

    def __init__(self, hf_ds):
        # Indexable dataset whose rows behave like dicts.
        self.data = hf_ds

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(conversation, [image])`` for row ``idx``.

        ``conversation`` is a two-message list (user question with an
        ``<image>`` placeholder, assistant answer); the image is also
        returned separately as a one-element list.
        """
        row = self.data[idx]

        image = row["image"]
        if not isinstance(image, Image.Image):
            # Non-PIL input is converted; presumably an array — TODO confirm.
            image = Image.fromarray(image)
        image = image.convert("RGB")

        question = str(row["query"])

        # A label may be a scalar or a list of answers; use the first answer.
        # Always cast to str, and tolerate an empty list instead of raising.
        label = row["label"]
        if isinstance(label, (list, tuple)):
            label = label[0] if label else ""
        answer = str(label)

        conversation = [
            {"role": "<|User|>", "content": f"<image>\n{question}", "images": [image]},
            {"role": "<|Assistant|>", "content": answer},
        ]
        return conversation, [image]
196
+
197
+
198
+ def _find_asst_start(ids, asst_tok_ids):
199
+ """Return index just AFTER the <|Assistant|> token sequence."""
200
+ for j in range(len(ids) - len(asst_tok_ids) + 1):
201
+ if ids[j: j + len(asst_tok_ids)] == asst_tok_ids:
202
+ return j + len(asst_tok_ids)
203
+ return None
204
+
205
+
206
# Token ids of the assistant role marker; used to locate where the reply starts.
_ASST_TOKEN_IDS = tokenizer.encode("<|Assistant|>", add_special_tokens=False)

# System prompt prepended to every training conversation.
_SYSTEM_PROMPT = (
    "You are a helpful assistant that answers questions "
    "about charts and graphs accurately and concisely."
)


def collate_fn(batch):
    """Turn ``(conversation, images)`` pairs into a padded batch with labels.

    Args:
        batch: Sequence of ``(conversation, images)`` tuples.

    Returns:
        The processor's batch object with an added ``"labels"`` entry in which
        padding and everything up to and including the ``<|Assistant|>``
        marker are set to -100.
    """
    # The processor formats ONE conversation per call, so each sample is
    # processed on its own and the results are padded together afterwards.
    # NOTE(review): inference_mode=False is intended to keep the trailing EOS
    # token so the model learns to stop — confirm against the processor source.
    samples = [
        processor(
            conversations=conversation,
            images=images,
            force_batchify=False,
            inference_mode=False,
            system_prompt=_SYSTEM_PROMPT,
        )
        for conversation, images in batch
    ]
    inputs = processor.batchify(samples)

    input_ids = inputs["input_ids"]
    labels = input_ids.clone()
    labels[input_ids == tokenizer.pad_token_id] = -100

    # Mask user/system tokens — only compute loss on assistant reply
    for i in range(labels.shape[0]):
        asst_start = _find_asst_start(input_ids[i].tolist(), _ASST_TOKEN_IDS)
        if asst_start is None:
            # Without the marker the prompt cannot be masked; make that visible.
            log.warning("Assistant marker not found in sample %d; prompt tokens are not masked", i)
            continue
        labels[i, :asst_start] = -100

    inputs["labels"] = labels
    return inputs
231
+
232
+
233
# Wrap both splits and build their loaders; only the training loader shuffles.
train_ds = ChartQADataset(train_raw)
val_ds = ChartQADataset(val_raw)
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    num_workers=2,
    pin_memory=True,
)
train_dl = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_dl = DataLoader(val_ds, shuffle=False, **_loader_kwargs)
log.info(f"DataLoaders ready — {len(train_dl)} train steps/epoch")

# ──────────────────────────────────────────────────────────────────────────────
# ███ OPTIMIZER & SCHEDULER █████████████████████████████████████████████████
# ──────────────────────────────────────────────────────────────────────────────
# Only parameters that require gradients are handed to the optimiser.
_trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(_trainable_params, lr=LR, weight_decay=0.01)

# One optimiser step per GRAD_ACCUM micro-batches, rounded up, for every epoch.
_steps_per_epoch = math.ceil(len(train_dl) / GRAD_ACCUM)
total_opt_steps = _steps_per_epoch * NUM_EPOCHS
# Warm up for 10% of training, capped at 50 steps.
warmup_steps = min(50, total_opt_steps // 10)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_opt_steps,
)
log.info(f"Optimiser ready — total opt_steps={total_opt_steps} warmup={warmup_steps}")
256
+
257
+ # ──────────────────────────────────────────────────────────────────────────────
258
+ # ███ TRAINING LOOP █████████████████████████████████████████████████████████
259
+ # ──────────────────────────────────────────────────────────────────────────────
260
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
global_step = 0       # micro-batches processed so far (all epochs)
opt_step = 0          # optimiser steps taken so far (all epochs)
best_val_loss = float("inf")
running_loss = 0.0    # sum of micro-batch losses since the last log line
running_count = 0     # number of micro-batches contributing to running_loss


def _language_decoder():
    """Return the text decoder of the wrapped model.

    DeepSeek-VL2 exposes it as ``language``; ``language_model`` is accepted as
    a second candidate for wrappers that use that name.
    """
    for attr in ("language", "language_model"):
        decoder = getattr(model, attr, None)
        if decoder is not None:
            return decoder
    raise AttributeError("model exposes neither 'language' nor 'language_model'")


def _to_device(batch):
    """Copy every tensor in the batch to the training device."""
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()}


def _forward_loss(batch):
    """Run one forward pass under bf16 autocast and return the LM loss."""
    with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
        inputs_embeds = model.prepare_inputs_embeds(**batch)
        out = _language_decoder()(
            inputs_embeds=inputs_embeds,
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
    return out.loss


log.info("=" * 60)
log.info("Training start")
log.info("=" * 60)

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    micro_batches = len(train_dl)

    for micro_idx, batch in enumerate(train_dl, start=1):
        loss = _forward_loss(_to_device(batch))
        # Scale so the accumulated gradient is the mean over the group.
        (loss / GRAD_ACCUM).backward()
        running_loss += loss.item()
        running_count += 1
        global_step += 1

        # Step after every full accumulation group AND on the final micro-batch
        # of the epoch, so leftover gradients are applied rather than carried
        # over or dropped. This gives ceil(len(train_dl) / GRAD_ACCUM) steps
        # per epoch, matching total_opt_steps used by the scheduler.
        group_complete = micro_idx % GRAD_ACCUM == 0
        if not group_complete and micro_idx != micro_batches:
            continue

        torch.nn.utils.clip_grad_norm_(
            [p for p in model.parameters() if p.requires_grad], 1.0
        )
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        opt_step += 1

        if opt_step % LOG_EVERY == 0:
            # Average over the micro-batches actually seen in this window.
            avg_loss = running_loss / max(running_count, 1)
            lr_now = scheduler.get_last_lr()[0]
            vram_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
            log.info(
                f"Epoch {epoch} | step {opt_step}/{total_opt_steps} | "
                f"loss={avg_loss:.4f} | lr={lr_now:.2e} | VRAM={vram_gb:.1f}GB"
            )
            tlog({
                "train/loss": avg_loss, "train/lr": lr_now,
                "train/vram_gb": vram_gb,
                "epoch": epoch, "opt_step": opt_step,
            })
            if avg_loss > 10.0:
                talert("loss_diverging",
                       f"loss={avg_loss:.2f} at step {opt_step} — reduce lr by 10x",
                       "ERROR")
            elif avg_loss > 3.5 and opt_step > 100:
                talert("slow_convergence",
                       f"loss={avg_loss:.2f} at step {opt_step} — check lr schedule",
                       "WARN")
            running_loss = 0.0
            running_count = 0

        if opt_step % SAVE_STEPS == 0:
            ckpt = f"{OUTPUT_DIR}/checkpoint-{opt_step}"
            model.save_pretrained(ckpt)
            log.info(f"Checkpoint -> {ckpt}")
            talert("checkpoint_saved", f"step={opt_step} -> {ckpt}", "INFO")

    # ── Validation ────────────────────────────────────────────────────────────
    model.eval()
    val_loss, val_steps = 0.0, 0
    with torch.no_grad():
        for batch in val_dl:
            val_loss += _forward_loss(_to_device(batch)).item()
            val_steps += 1

    avg_val = val_loss / max(val_steps, 1)
    log.info(f"---- Epoch {epoch} val_loss={avg_val:.4f} best={best_val_loss:.4f} ----")
    tlog({"val/loss": avg_val, "epoch": epoch})

    if avg_val < best_val_loss:
        # New best: keep a copy of the adapter under OUTPUT_DIR/best.
        best_val_loss = avg_val
        model.save_pretrained(f"{OUTPUT_DIR}/best")
        log.info(f"New best -> val_loss={best_val_loss:.4f}")
        talert("new_best", f"val_loss={best_val_loss:.4f} epoch={epoch}", "INFO")
    elif epoch > 1 and avg_val > best_val_loss * 1.05:
        talert("val_degrading",
               f"val_loss={avg_val:.4f} > best*1.05={best_val_loss*1.05:.4f} — possible overfit",
               "WARN")
358
+
359
+ # ──────────────────────────────────────────────────────────────────────────────
360
+ # ███ PUSH TO HUB ███████████████████████████████████████████████████████████
361
+ # ──────────────────────────────────────────────────────────────────────────────
362
log.info(f"Pushing best checkpoint to {HUB_MODEL_ID} ...")
best_path = Path(f"{OUTPUT_DIR}/best")

if best_path.is_dir():
    # The in-memory weights are those of the LAST epoch. Reload the saved best
    # adapter into the default adapter slot so the push matches the log message
    # and the "Best val loss" reported in the model card.
    model.load_adapter(str(best_path), adapter_name="default")
else:
    # No epoch improved on the initial best (e.g. non-finite validation loss);
    # fall back to the current weights so the steps below still have a folder.
    log.warning("No best checkpoint found — saving current adapter to %s", best_path)
    best_path.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(str(best_path))

# Model card
card = f"""---
license: other
tags:
- deepseek-vl2
- chart-qa
- vision-language
- lora
- peft
base_model: {MODEL_ID}
datasets:
- HuggingFaceM4/ChartQA
---

# DeepSeek-VL2-tiny x ChartQA LoRA

Fine-tuned [`{MODEL_ID}`](https://huggingface.co/{MODEL_ID}) on
[ChartQA](https://huggingface.co/datasets/HuggingFaceM4/ChartQA)
with LoRA (r={LORA_R}, a={LORA_ALPHA}).

| | |
|--|--|
| Base | `{MODEL_ID}` |
| LoRA r / a | {LORA_R} / {LORA_ALPHA} |
| Target modules | {', '.join(LORA_TARGETS)} |
| LR | {LR} |
| Epochs | {NUM_EPOCHS} |
| Effective batch | {BATCH_SIZE * GRAD_ACCUM} |
| Best val loss | {best_val_loss:.4f} |

## Load adapter

```python
from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
from peft import PeftModel
from transformers import AutoModelForCausalLM
import torch

model_id = "{MODEL_ID}"
adapter_id = "{HUB_MODEL_ID}"

base = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True,
                                            torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, adapter_id).eval().cuda()
proc = DeepseekVLV2Processor.from_pretrained(model_id)
```
"""
# NOTE(review): this README is written next to the local best checkpoint.
# push_to_hub below exports the adapter from memory and does not appear to
# upload this file — confirm, and upload best_path explicitly if the custom
# card should appear on the Hub.
(best_path / "README.md").write_text(card, encoding="utf-8")

model.push_to_hub(HUB_MODEL_ID, commit_message="LoRA adapter — ChartQA fine-tune")
processor.tokenizer.push_to_hub(HUB_MODEL_ID, commit_message="Add tokenizer")

log.info(f"Done! https://huggingface.co/{HUB_MODEL_ID}")
talert("training_complete",
       f"best_val_loss={best_val_loss:.4f} model -> https://huggingface.co/{HUB_MODEL_ID}",
       "INFO")

# Close the tracking run if one was opened.
if _USE_TRACKIO:
    trackio.finish()