YUNGHUI2024 commited on
Commit
e23264b
Β·
verified Β·
1 Parent(s): 5044f54

Add smoke_test.py for local environment validation

Browse files
Files changed (1) hide show
  1. smoke_test.py +79 -0
smoke_test.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ smoke_test.py β€” ζœ¬ζ©ŸεΏ«ι€Ÿι©—θ­‰
4
+ εŸ·θ‘Œζ™‚ι–“ < 2 minοΌ›η’Ίθͺ deepseek_vl 可載ε…₯、processor 可處理 ChartQA
5
+ """
6
+ import sys, subprocess, logging
7
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
8
+ log = logging.getLogger(__name__)
9
+
10
+ # ─── Install deepseek_vl if missing ──────────────────────────────────────────
11
+ try:
12
+ from deepseek_vl.models import DeepseekVLV2Processor
13
+ except ImportError:
14
+ log.info("Installing deepseek_vl …")
15
+ subprocess.run(
16
+ [sys.executable, "-m", "pip", "install", "-q",
17
+ "git+https://github.com/deepseek-ai/DeepSeek-VL2.git"],
18
+ check=True,
19
+ )
20
+ from deepseek_vl.models import DeepseekVLV2Processor
21
+
22
+ import torch
23
+ from datasets import load_dataset
24
+ from PIL import Image
25
+ from transformers import AutoModelForCausalLM
26
+ from peft import LoraConfig, get_peft_model, TaskType
27
+
28
+ MODEL_ID = "deepseek-ai/deepseek-vl2-tiny"
29
+
30
+ # ─── 1. Processor ────────────────────────────────────────────────────────────
31
+ log.info("Loading processor …")
32
+ proc = DeepseekVLV2Processor.from_pretrained(MODEL_ID)
33
+ log.info("Processor OK βœ“")
34
+
35
+ # ─── 2. ChartQA mini sample ──────────────────────────────────────────────────
36
+ log.info("Loading 4 ChartQA samples …")
37
+ ds = load_dataset("HuggingFaceM4/ChartQA", split="val[:4]")
38
+ for row in ds:
39
+ img = row["image"]
40
+ if not isinstance(img, Image.Image):
41
+ img = Image.fromarray(img)
42
+ img = img.convert("RGB")
43
+ q = str(row["query"])
44
+ ans = row["label"][0] if isinstance(row["label"], list) else str(row["label"])
45
+ conv = [
46
+ {"role": "<|User|>", "content": f"<image>\n{q}", "images": [img]},
47
+ {"role": "<|Assistant|>", "content": ans},
48
+ ]
49
+ out = proc(conversations=[conv], images=[img], force_batchify=True, system_prompt="")
50
+ log.info(f" input_ids shape = {out['input_ids'].shape} query='{q[:40]}...'")
51
+ log.info("Processor + ChartQA collation OK βœ“")
52
+
53
+ # ─── 3. Model load + LoRA (no forward pass β€” saves time) ─────────────────────
54
+ log.info("Loading model (this takes ~1–2 min on first run) …")
55
+ model = AutoModelForCausalLM.from_pretrained(
56
+ MODEL_ID, trust_remote_code=True, torch_dtype=torch.bfloat16,
57
+ )
58
+ lora = LoraConfig(
59
+ task_type=TaskType.CAUSAL_LM, r=16, lora_alpha=32,
60
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
61
+ bias="none",
62
+ )
63
+ model = get_peft_model(model, lora)
64
+ model.print_trainable_parameters()
65
+ log.info("LoRA wrapping OK βœ“")
66
+
67
+ if torch.cuda.is_available():
68
+ model = model.to("cuda")
69
+ mem = torch.cuda.memory_reserved() / 1e9
70
+ log.info(f"VRAM reserved = {mem:.1f} GB")
71
+ if mem > 11.5:
72
+ log.warning("VRAM > 11.5 GB β€” training with batch=1 might OOM. "
73
+ "Try reducing MAX_TRAIN or set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True")
74
+ else:
75
+ log.info("VRAM looks fine for batch_size=1 training βœ“")
76
+
77
+ log.info("=" * 50)
78
+ log.info("Smoke test PASSED β€” you can now run: python train_pipeline.py")
79
+ log.info("=" * 50)