Faaz commited on
Commit Β·
2ff5c54
1
Parent(s): 59c6c97
Day 3 COMPLETE: Full model architecture
Browse filesFiles added:
- src/model/architecture.py (Qwen2.5-Coder-7B + LoRA)
- src/model/vision_encoder.py (CLIP ViT-L/14, 256 tokens)
- src/model/fusion_layer.py (Linear+LayerNorm prepend)
- src/model/mindi_model.py (MINDI15 complete model)
- src/training/mindi_trainer.py (3-phase MI300X trainer)
- scripts/train.py (master training script)
- configs/training_config.yaml (MI300X config)
- setup_mi300x.sh (MI300X setup script)
- scripts/quality_filter.py, split_data.py, data_stats.py
- scripts/upload_everything_to_hf.py
Model: Qwen2.5-Coder-7B + LoRA + CLIP Vision
Ready for MI300X training!
- configs/training_config.yaml +96 -46
- data/processed/filter_report.json +54 -0
- data/processed/split_meta.json +66 -0
- scripts/data_stats.py +273 -0
- scripts/quality_filter.py +472 -0
- scripts/split_data.py +207 -0
- scripts/train.py +326 -17
- scripts/upload_everything_to_hf.py +354 -0
- setup_mi300x.sh +145 -0
- src/model/architecture.py +289 -0
- src/model/fusion_layer.py +221 -0
- src/model/mindi_model.py +620 -0
- src/model/vision_encoder.py +196 -37
- src/training/mindi_trainer.py +745 -0
configs/training_config.yaml
CHANGED
|
@@ -1,57 +1,107 @@
|
|
| 1 |
# ==========================================
|
| 2 |
# MINDI 1.5 Vision-Coder β Training Configuration
|
|
|
|
| 3 |
# ==========================================
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
training:
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
max_grad_norm: 1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
#
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
save_total_limit: 5
|
| 30 |
-
checkpoint_dir: "./checkpoints"
|
| 31 |
-
resume_from_checkpoint: null
|
| 32 |
-
|
| 33 |
-
# Logging
|
| 34 |
-
logging_steps: 10
|
| 35 |
-
log_dir: "./logs/training"
|
| 36 |
-
report_to: "wandb"
|
| 37 |
-
|
| 38 |
-
# Evaluation
|
| 39 |
-
eval_strategy: "steps"
|
| 40 |
-
eval_steps: 250
|
| 41 |
-
eval_samples: 1000
|
| 42 |
-
|
| 43 |
-
# Memory optimization (for RTX 4060 local testing)
|
| 44 |
-
local_overrides:
|
| 45 |
-
batch_size: 1
|
| 46 |
-
gradient_accumulation_steps: 16
|
| 47 |
-
max_seq_length: 2048
|
| 48 |
-
gradient_checkpointing: true
|
| 49 |
-
optim: "adamw_8bit"
|
| 50 |
-
|
| 51 |
-
wandb:
|
| 52 |
-
project: "mindi-1.5-vision-coder"
|
| 53 |
-
entity: "mindigenous"
|
| 54 |
tags:
|
| 55 |
- "mindi-1.5"
|
| 56 |
- "lora"
|
| 57 |
- "vision-coder"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# ==========================================
|
| 2 |
# MINDI 1.5 Vision-Coder β Training Configuration
|
| 3 |
+
# Optimized for AMD MI300X 192GB VRAM
|
| 4 |
# ==========================================
|
| 5 |
|
| 6 |
+
# ββ Model ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 7 |
+
model:
|
| 8 |
+
name: "Qwen/Qwen2.5-Coder-7B-Instruct"
|
| 9 |
+
hidden_size: 4096
|
| 10 |
+
dtype: "bf16" # bf16 required for MI300X stability (NOT fp16)
|
| 11 |
+
use_compile: true # torch.compile() works on ROCm
|
| 12 |
+
gradient_checkpointing: true # Save VRAM even with 192GB
|
| 13 |
+
|
| 14 |
+
# ββ LoRA βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 15 |
+
lora:
|
| 16 |
+
r: 64
|
| 17 |
+
alpha: 128
|
| 18 |
+
dropout: 0.05
|
| 19 |
+
bias: "none"
|
| 20 |
+
task_type: "CAUSAL_LM"
|
| 21 |
+
target_modules:
|
| 22 |
+
- q_proj
|
| 23 |
+
- k_proj
|
| 24 |
+
- v_proj
|
| 25 |
+
- o_proj
|
| 26 |
+
- gate_proj
|
| 27 |
+
- up_proj
|
| 28 |
+
- down_proj
|
| 29 |
+
|
| 30 |
+
# ββ Vision βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 31 |
+
vision:
|
| 32 |
+
clip_model: "openai/clip-vit-large-patch14"
|
| 33 |
+
visual_tokens: 256 # 16Γ16 patches from ViT-L/14
|
| 34 |
+
projection_size: 4096 # Must match model.hidden_size
|
| 35 |
+
freeze_clip: true # Freeze CLIP backbone
|
| 36 |
+
|
| 37 |
+
# ββ Training Phases ββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
training:
|
| 39 |
+
# Phase 1: LoRA only β teach coding patterns
|
| 40 |
+
phase1:
|
| 41 |
+
steps: 5000
|
| 42 |
+
lr: 2.0e-4
|
| 43 |
+
batch_size: 16 # MI300X can handle large batches
|
| 44 |
+
warmup_steps: 100
|
| 45 |
+
data_filter: "code_only"
|
| 46 |
+
|
| 47 |
+
# Phase 2: Vision bridge only β align visual tokens
|
| 48 |
+
phase2:
|
| 49 |
+
steps: 2500
|
| 50 |
+
lr: 1.0e-5
|
| 51 |
+
batch_size: 8 # Smaller batch for vision bridge
|
| 52 |
+
warmup_steps: 50
|
| 53 |
+
data_filter: "websight_only"
|
| 54 |
+
|
| 55 |
+
# Phase 3: All trainable β joint fine-tuning
|
| 56 |
+
phase3:
|
| 57 |
+
steps: 2500
|
| 58 |
+
lr: 5.0e-5
|
| 59 |
+
batch_size: 12
|
| 60 |
+
warmup_steps: 50
|
| 61 |
+
data_filter: "all"
|
| 62 |
+
|
| 63 |
+
# Shared training settings
|
| 64 |
+
grad_accumulation: 4
|
| 65 |
max_grad_norm: 1.0
|
| 66 |
+
eval_every: 250
|
| 67 |
+
save_every: 500
|
| 68 |
+
|
| 69 |
+
# ββ Data βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 70 |
+
data:
|
| 71 |
+
train_file: "data/processed/train.jsonl" # 4.18GB, 1,304,486 examples
|
| 72 |
+
val_file: "data/processed/val.jsonl" # 0.23GB, 72,471 examples
|
| 73 |
+
max_length: 4096
|
| 74 |
+
shuffle_buffer: 10000 # Streaming shuffle buffer size
|
| 75 |
+
num_workers: 4 # DataLoader workers
|
| 76 |
+
pin_memory: true
|
| 77 |
+
prefetch_factor: 2
|
| 78 |
|
| 79 |
+
# ββ Logging ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 80 |
+
logging:
|
| 81 |
+
wandb_project: "mindi-1.5-vision-coder"
|
| 82 |
+
wandb_entity: "mindigenous"
|
| 83 |
+
log_every: 10 # Log metrics every N steps
|
| 84 |
+
log_dir: "logs/training"
|
| 85 |
+
sample_every: 500 # Generate sample outputs every N steps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
tags:
|
| 87 |
- "mindi-1.5"
|
| 88 |
- "lora"
|
| 89 |
- "vision-coder"
|
| 90 |
+
- "mi300x"
|
| 91 |
+
|
| 92 |
+
# ββ Output βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 93 |
+
output:
|
| 94 |
+
checkpoint_dir: "checkpoints/training"
|
| 95 |
+
best_model: "checkpoints/best"
|
| 96 |
+
hf_repo: "Mindigenous/MINDI-1.5-Vision-Coder"
|
| 97 |
+
push_every_phase: true
|
| 98 |
+
|
| 99 |
+
# ββ Local Dev Overrides (RTX 4060 8GB) ββββββββββββββββββββββββ
|
| 100 |
+
# Apply these when testing locally with --dry_run
|
| 101 |
+
local_overrides:
|
| 102 |
+
batch_size: 1
|
| 103 |
+
gradient_accumulation_steps: 16
|
| 104 |
+
max_length: 2048
|
| 105 |
+
gradient_checkpointing: true
|
| 106 |
+
use_compile: false
|
| 107 |
+
num_workers: 0
|
data/processed/filter_report.json
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"input_count": 1481497,
|
| 3 |
+
"kept_count": 1449428,
|
| 4 |
+
"rejected_count": 32069,
|
| 5 |
+
"kept_pct": 97.84,
|
| 6 |
+
"avg_tokens": 593.1,
|
| 7 |
+
"avg_quality": 6.487,
|
| 8 |
+
"total_tokens": 859694776,
|
| 9 |
+
"elapsed_seconds": 1023.6,
|
| 10 |
+
"filter_settings": {
|
| 11 |
+
"min_tokens": 50,
|
| 12 |
+
"max_tokens": 4096,
|
| 13 |
+
"min_quality": 5.0
|
| 14 |
+
},
|
| 15 |
+
"rejection_breakdown": {
|
| 16 |
+
"too_many_tokens": 30637,
|
| 17 |
+
"boilerplate_content": 1373,
|
| 18 |
+
"duplicate_content": 59
|
| 19 |
+
},
|
| 20 |
+
"source_kept": {
|
| 21 |
+
"codealpaca": 59241,
|
| 22 |
+
"codefeedback": 149865,
|
| 23 |
+
"websight": 250987,
|
| 24 |
+
"synthetic_nextjs": 90000,
|
| 25 |
+
"search_examples": 15000,
|
| 26 |
+
"sandbox_examples": 9000,
|
| 27 |
+
"starcoderdata": 569350,
|
| 28 |
+
"evol_code": 155998,
|
| 29 |
+
"magicoder": 149987
|
| 30 |
+
},
|
| 31 |
+
"source_rejected": {
|
| 32 |
+
"codealpaca": 741,
|
| 33 |
+
"codefeedback": 132,
|
| 34 |
+
"starcoderdata": 30650,
|
| 35 |
+
"evol_code": 518,
|
| 36 |
+
"magicoder": 13,
|
| 37 |
+
"websight": 15
|
| 38 |
+
},
|
| 39 |
+
"type_distribution": {
|
| 40 |
+
"code_generation": 1183441,
|
| 41 |
+
"vision_code": 250987,
|
| 42 |
+
"search": 15000
|
| 43 |
+
},
|
| 44 |
+
"language_distribution": {
|
| 45 |
+
"unknown": 490305,
|
| 46 |
+
"typescript": 375859,
|
| 47 |
+
"javascript": 298497,
|
| 48 |
+
"python": 211842,
|
| 49 |
+
"html": 36371,
|
| 50 |
+
"java": 32458,
|
| 51 |
+
"rust": 3709,
|
| 52 |
+
"go": 387
|
| 53 |
+
}
|
| 54 |
+
}
|
data/processed/split_meta.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"total": 1449428,
|
| 3 |
+
"train_count": 1304486,
|
| 4 |
+
"val_count": 72471,
|
| 5 |
+
"test_count": 72471,
|
| 6 |
+
"train_pct": 90.0,
|
| 7 |
+
"val_pct": 5.0,
|
| 8 |
+
"test_pct": 5.0,
|
| 9 |
+
"seed": 42,
|
| 10 |
+
"source_breakdown": {
|
| 11 |
+
"codealpaca": {
|
| 12 |
+
"total": 59241,
|
| 13 |
+
"train": 53317,
|
| 14 |
+
"val": 2962,
|
| 15 |
+
"test": 2962
|
| 16 |
+
},
|
| 17 |
+
"codefeedback": {
|
| 18 |
+
"total": 149865,
|
| 19 |
+
"train": 134879,
|
| 20 |
+
"val": 7493,
|
| 21 |
+
"test": 7493
|
| 22 |
+
},
|
| 23 |
+
"evol_code": {
|
| 24 |
+
"total": 155998,
|
| 25 |
+
"train": 140398,
|
| 26 |
+
"val": 7800,
|
| 27 |
+
"test": 7800
|
| 28 |
+
},
|
| 29 |
+
"magicoder": {
|
| 30 |
+
"total": 149987,
|
| 31 |
+
"train": 134989,
|
| 32 |
+
"val": 7499,
|
| 33 |
+
"test": 7499
|
| 34 |
+
},
|
| 35 |
+
"sandbox_examples": {
|
| 36 |
+
"total": 9000,
|
| 37 |
+
"train": 8100,
|
| 38 |
+
"val": 450,
|
| 39 |
+
"test": 450
|
| 40 |
+
},
|
| 41 |
+
"search_examples": {
|
| 42 |
+
"total": 15000,
|
| 43 |
+
"train": 13500,
|
| 44 |
+
"val": 750,
|
| 45 |
+
"test": 750
|
| 46 |
+
},
|
| 47 |
+
"starcoderdata": {
|
| 48 |
+
"total": 569350,
|
| 49 |
+
"train": 512414,
|
| 50 |
+
"val": 28468,
|
| 51 |
+
"test": 28468
|
| 52 |
+
},
|
| 53 |
+
"synthetic_nextjs": {
|
| 54 |
+
"total": 90000,
|
| 55 |
+
"train": 81000,
|
| 56 |
+
"val": 4500,
|
| 57 |
+
"test": 4500
|
| 58 |
+
},
|
| 59 |
+
"websight": {
|
| 60 |
+
"total": 250987,
|
| 61 |
+
"train": 225889,
|
| 62 |
+
"val": 12549,
|
| 63 |
+
"test": 12549
|
| 64 |
+
}
|
| 65 |
+
}
|
| 66 |
+
}
|
scripts/data_stats.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
MINDI 1.5 Vision-Coder β Dataset Statistics Report
|
| 4 |
+
|
| 5 |
+
Generates comprehensive statistics for the final train/val/test splits:
|
| 6 |
+
- Total counts and sizes
|
| 7 |
+
- Token distribution (min, max, mean, median, p95, p99)
|
| 8 |
+
- Quality score distribution
|
| 9 |
+
- Source breakdown
|
| 10 |
+
- Type breakdown
|
| 11 |
+
- Language breakdown
|
| 12 |
+
- Special token usage
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
python scripts/data_stats.py # Full report
|
| 16 |
+
python scripts/data_stats.py --split train # Stats for train only
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import argparse
|
| 22 |
+
import json
|
| 23 |
+
import statistics
|
| 24 |
+
import sys
|
| 25 |
+
import time
|
| 26 |
+
from collections import Counter
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
|
| 29 |
+
# ββ Paths βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 30 |
+
|
| 31 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 32 |
+
PROCESSED_DIR = PROJECT_ROOT / "data" / "processed"
|
| 33 |
+
|
| 34 |
+
SPLIT_FILES = {
|
| 35 |
+
"train": PROCESSED_DIR / "train.jsonl",
|
| 36 |
+
"val": PROCESSED_DIR / "val.jsonl",
|
| 37 |
+
"test": PROCESSED_DIR / "test.jsonl",
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
REPORT_FILE = PROCESSED_DIR / "dataset_stats.json"
|
| 41 |
+
|
| 42 |
+
# ββ Special tokens to check ββββββββββββββββββββββββββββββββββββββββββ
|
| 43 |
+
|
| 44 |
+
SPECIAL_TOKENS = [
|
| 45 |
+
"<|think_start|>", "<|think_end|>",
|
| 46 |
+
"<|code_start|>", "<|code_end|>",
|
| 47 |
+
"<|critique_start|>", "<|critique_end|>",
|
| 48 |
+
"<|suggest_start|>", "<|suggest_end|>",
|
| 49 |
+
"<|file_start|>", "<|file_end|>",
|
| 50 |
+
"<|search_start|>", "<|search_end|>",
|
| 51 |
+
"<|sandbox_start|>", "<|sandbox_end|>",
|
| 52 |
+
"<|vision_start|>", "<|vision_end|>",
|
| 53 |
+
"<|error_start|>", "<|error_end|>",
|
| 54 |
+
"<|fix_start|>", "<|fix_end|>",
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def percentile(sorted_data: list[int | float], p: float) -> float:
|
| 59 |
+
"""Calculate the p-th percentile from sorted data."""
|
| 60 |
+
if not sorted_data:
|
| 61 |
+
return 0.0
|
| 62 |
+
k = (len(sorted_data) - 1) * (p / 100.0)
|
| 63 |
+
f = int(k)
|
| 64 |
+
c = f + 1
|
| 65 |
+
if c >= len(sorted_data):
|
| 66 |
+
return float(sorted_data[f])
|
| 67 |
+
return sorted_data[f] + (k - f) * (sorted_data[c] - sorted_data[f])
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def compute_stats(file_path: Path, split_name: str) -> dict:
|
| 71 |
+
"""Compute statistics for a single split file."""
|
| 72 |
+
|
| 73 |
+
if not file_path.exists():
|
| 74 |
+
return {"error": f"File not found: {file_path}"}
|
| 75 |
+
|
| 76 |
+
tokens_list: list[int] = []
|
| 77 |
+
quality_list: list[float] = []
|
| 78 |
+
source_counts: Counter = Counter()
|
| 79 |
+
type_counts: Counter = Counter()
|
| 80 |
+
lang_counts: Counter = Counter()
|
| 81 |
+
framework_counts: Counter = Counter()
|
| 82 |
+
has_vision_count = 0
|
| 83 |
+
special_token_counts: Counter = Counter()
|
| 84 |
+
msg_count_dist: Counter = Counter() # number of messages per example
|
| 85 |
+
total_chars = 0
|
| 86 |
+
|
| 87 |
+
count = 0
|
| 88 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
| 89 |
+
for line in f:
|
| 90 |
+
line = line.strip()
|
| 91 |
+
if not line:
|
| 92 |
+
continue
|
| 93 |
+
try:
|
| 94 |
+
ex = json.loads(line)
|
| 95 |
+
except json.JSONDecodeError:
|
| 96 |
+
continue
|
| 97 |
+
|
| 98 |
+
count += 1
|
| 99 |
+
meta = ex.get("metadata", {})
|
| 100 |
+
|
| 101 |
+
# Token count
|
| 102 |
+
tokens = meta.get("tokens", 0)
|
| 103 |
+
tokens_list.append(tokens)
|
| 104 |
+
|
| 105 |
+
# Quality score
|
| 106 |
+
quality = meta.get("quality_score", 0.0)
|
| 107 |
+
quality_list.append(quality)
|
| 108 |
+
|
| 109 |
+
# Source, type, language, framework
|
| 110 |
+
source_counts[ex.get("source", "unknown")] += 1
|
| 111 |
+
type_counts[ex.get("type", "unknown")] += 1
|
| 112 |
+
lang_counts[meta.get("language", "unknown")] += 1
|
| 113 |
+
framework_counts[meta.get("framework", "none")] += 1
|
| 114 |
+
|
| 115 |
+
# Vision
|
| 116 |
+
if meta.get("has_vision", False):
|
| 117 |
+
has_vision_count += 1
|
| 118 |
+
|
| 119 |
+
# Messages
|
| 120 |
+
messages = ex.get("messages", [])
|
| 121 |
+
msg_count_dist[len(messages)] += 1
|
| 122 |
+
|
| 123 |
+
# Special tokens in assistant content
|
| 124 |
+
for msg in messages:
|
| 125 |
+
if msg.get("role") == "assistant":
|
| 126 |
+
content = msg.get("content", "")
|
| 127 |
+
total_chars += len(content)
|
| 128 |
+
for tok in SPECIAL_TOKENS:
|
| 129 |
+
if tok in content:
|
| 130 |
+
special_token_counts[tok] += 1
|
| 131 |
+
|
| 132 |
+
# Sort for percentile computation
|
| 133 |
+
tokens_sorted = sorted(tokens_list)
|
| 134 |
+
quality_sorted = sorted(quality_list)
|
| 135 |
+
|
| 136 |
+
file_size_mb = file_path.stat().st_size / (1024 * 1024)
|
| 137 |
+
|
| 138 |
+
stats = {
|
| 139 |
+
"split": split_name,
|
| 140 |
+
"file": file_path.name,
|
| 141 |
+
"file_size_mb": round(file_size_mb, 1),
|
| 142 |
+
"count": count,
|
| 143 |
+
"total_tokens": sum(tokens_list),
|
| 144 |
+
"total_chars_assistant": total_chars,
|
| 145 |
+
"has_vision": has_vision_count,
|
| 146 |
+
"tokens": {
|
| 147 |
+
"min": min(tokens_sorted) if tokens_sorted else 0,
|
| 148 |
+
"max": max(tokens_sorted) if tokens_sorted else 0,
|
| 149 |
+
"mean": round(statistics.mean(tokens_list), 1) if tokens_list else 0,
|
| 150 |
+
"median": round(statistics.median(tokens_list), 1) if tokens_list else 0,
|
| 151 |
+
"stdev": round(statistics.stdev(tokens_list), 1) if len(tokens_list) > 1 else 0,
|
| 152 |
+
"p5": round(percentile(tokens_sorted, 5), 1),
|
| 153 |
+
"p25": round(percentile(tokens_sorted, 25), 1),
|
| 154 |
+
"p75": round(percentile(tokens_sorted, 75), 1),
|
| 155 |
+
"p95": round(percentile(tokens_sorted, 95), 1),
|
| 156 |
+
"p99": round(percentile(tokens_sorted, 99), 1),
|
| 157 |
+
},
|
| 158 |
+
"quality_score": {
|
| 159 |
+
"min": round(min(quality_sorted), 2) if quality_sorted else 0,
|
| 160 |
+
"max": round(max(quality_sorted), 2) if quality_sorted else 0,
|
| 161 |
+
"mean": round(statistics.mean(quality_list), 2) if quality_list else 0,
|
| 162 |
+
"median": round(statistics.median(quality_list), 2) if quality_list else 0,
|
| 163 |
+
},
|
| 164 |
+
"source_distribution": dict(source_counts.most_common()),
|
| 165 |
+
"type_distribution": dict(type_counts.most_common()),
|
| 166 |
+
"language_distribution": dict(lang_counts.most_common(30)),
|
| 167 |
+
"framework_distribution": dict(framework_counts.most_common(15)),
|
| 168 |
+
"messages_per_example": dict(sorted(msg_count_dist.items())),
|
| 169 |
+
"special_token_usage": dict(special_token_counts.most_common()),
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
return stats
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def print_stats(stats: dict) -> None:
|
| 176 |
+
"""Pretty-print statistics for a split."""
|
| 177 |
+
if "error" in stats:
|
| 178 |
+
print(f" ERROR: {stats['error']}")
|
| 179 |
+
return
|
| 180 |
+
|
| 181 |
+
print(f" Split: {stats['split']}")
|
| 182 |
+
print(f" File: {stats['file']} ({stats['file_size_mb']:.1f} MB)")
|
| 183 |
+
print(f" Count: {stats['count']:,}")
|
| 184 |
+
print(f" Total tokens: {stats['total_tokens']:,}")
|
| 185 |
+
print(f" Vision examples: {stats['has_vision']:,}")
|
| 186 |
+
print()
|
| 187 |
+
|
| 188 |
+
t = stats["tokens"]
|
| 189 |
+
print(f" Token distribution:")
|
| 190 |
+
print(f" Min: {t['min']:>8,} P5: {t['p5']:>8,.0f}")
|
| 191 |
+
print(f" P25: {t['p25']:>8,.0f} Median: {t['median']:>8,.0f}")
|
| 192 |
+
print(f" Mean: {t['mean']:>8,.0f} P75: {t['p75']:>8,.0f}")
|
| 193 |
+
print(f" P95: {t['p95']:>8,.0f} P99: {t['p99']:>8,.0f}")
|
| 194 |
+
print(f" Max: {t['max']:>8,} Stdev: {t['stdev']:>8,.0f}")
|
| 195 |
+
print()
|
| 196 |
+
|
| 197 |
+
q = stats["quality_score"]
|
| 198 |
+
print(f" Quality score: min={q['min']:.1f} mean={q['mean']:.1f} median={q['median']:.1f} max={q['max']:.1f}")
|
| 199 |
+
print()
|
| 200 |
+
|
| 201 |
+
print(f" Source distribution:")
|
| 202 |
+
for src, cnt in stats["source_distribution"].items():
|
| 203 |
+
pct = cnt / stats["count"] * 100
|
| 204 |
+
print(f" {src:<25s} {cnt:>10,} ({pct:5.1f}%)")
|
| 205 |
+
print()
|
| 206 |
+
|
| 207 |
+
print(f" Type distribution:")
|
| 208 |
+
for t_name, cnt in list(stats["type_distribution"].items())[:10]:
|
| 209 |
+
pct = cnt / stats["count"] * 100
|
| 210 |
+
print(f" {t_name:<25s} {cnt:>10,} ({pct:5.1f}%)")
|
| 211 |
+
print()
|
| 212 |
+
|
| 213 |
+
print(f" Language distribution (top 15):")
|
| 214 |
+
for lang, cnt in list(stats["language_distribution"].items())[:15]:
|
| 215 |
+
pct = cnt / stats["count"] * 100
|
| 216 |
+
print(f" {lang:<25s} {cnt:>10,} ({pct:5.1f}%)")
|
| 217 |
+
print()
|
| 218 |
+
|
| 219 |
+
if stats["special_token_usage"]:
|
| 220 |
+
print(f" Special token usage (examples containing token):")
|
| 221 |
+
for tok, cnt in stats["special_token_usage"].items():
|
| 222 |
+
pct = cnt / stats["count"] * 100
|
| 223 |
+
print(f" {tok:<25s} {cnt:>10,} ({pct:5.1f}%)")
|
| 224 |
+
print()
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def run_stats(split: str | None = None) -> None:
|
| 228 |
+
"""Generate and display statistics."""
|
| 229 |
+
start = time.time()
|
| 230 |
+
|
| 231 |
+
if split:
|
| 232 |
+
files = {split: SPLIT_FILES.get(split)}
|
| 233 |
+
if files[split] is None:
|
| 234 |
+
print(f"ERROR: Unknown split '{split}'. Choose from: {list(SPLIT_FILES.keys())}")
|
| 235 |
+
sys.exit(1)
|
| 236 |
+
else:
|
| 237 |
+
files = SPLIT_FILES
|
| 238 |
+
|
| 239 |
+
all_stats = {}
|
| 240 |
+
|
| 241 |
+
for name, path in files.items():
|
| 242 |
+
print("=" * 60)
|
| 243 |
+
print(f" Computing stats for: {name}")
|
| 244 |
+
print("=" * 60)
|
| 245 |
+
stats = compute_stats(path, name)
|
| 246 |
+
all_stats[name] = stats
|
| 247 |
+
print_stats(stats)
|
| 248 |
+
|
| 249 |
+
# Save JSON report
|
| 250 |
+
REPORT_FILE.parent.mkdir(parents=True, exist_ok=True)
|
| 251 |
+
with open(REPORT_FILE, "w", encoding="utf-8") as f:
|
| 252 |
+
json.dump(all_stats, f, indent=2)
|
| 253 |
+
print(f"Full report saved to: {REPORT_FILE.name}")
|
| 254 |
+
|
| 255 |
+
elapsed = time.time() - start
|
| 256 |
+
print(f"Stats generated in {elapsed:.1f}s")
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# ββ CLI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 260 |
+
|
| 261 |
+
def main():
|
| 262 |
+
parser = argparse.ArgumentParser(
|
| 263 |
+
description="MINDI Dataset Statistics β comprehensive split analysis",
|
| 264 |
+
)
|
| 265 |
+
parser.add_argument("--split", type=str, choices=["train", "val", "test"],
|
| 266 |
+
help="Compute stats for a single split only")
|
| 267 |
+
|
| 268 |
+
args = parser.parse_args()
|
| 269 |
+
run_stats(split=args.split)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
if __name__ == "__main__":
|
| 273 |
+
main()
|
scripts/quality_filter.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
MINDI 1.5 Vision-Coder β Quality Filter Pipeline
|
| 4 |
+
|
| 5 |
+
Filters mindi_all.jsonl to remove low-quality examples:
|
| 6 |
+
1. Token length filter β drop if <50 tokens or >4096 tokens
|
| 7 |
+
2. Duplicate detection β SHA-256 hash of assistant content
|
| 8 |
+
3. JSON structure check β valid schema with required fields
|
| 9 |
+
4. Special token check β assistant must have code_start/code_end pair
|
| 10 |
+
5. Quality score filter β keep only quality_score >= 5.0
|
| 11 |
+
6. Content heuristics β drop empty/trivial/boilerplate responses
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
python scripts/quality_filter.py # Full run
|
| 15 |
+
python scripts/quality_filter.py --dry-run # Preview only
|
| 16 |
+
python scripts/quality_filter.py --min-tokens 100 # Custom min tokens
|
| 17 |
+
python scripts/quality_filter.py --max-tokens 8192 # Custom max tokens
|
| 18 |
+
python scripts/quality_filter.py --min-quality 7.0 # Stricter quality
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import hashlib
|
| 25 |
+
import json
|
| 26 |
+
import sys
|
| 27 |
+
import time
|
| 28 |
+
from collections import Counter, defaultdict
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
|
| 31 |
+
# ββ Paths βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
+
|
| 33 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 34 |
+
INPUT_FILE = PROJECT_ROOT / "data" / "processed" / "mindi_all.jsonl"
|
| 35 |
+
OUTPUT_FILE = PROJECT_ROOT / "data" / "processed" / "mindi_filtered.jsonl"
|
| 36 |
+
REJECT_FILE = PROJECT_ROOT / "data" / "processed" / "mindi_rejected.jsonl"
|
| 37 |
+
REPORT_FILE = PROJECT_ROOT / "data" / "processed" / "filter_report.json"
|
| 38 |
+
|
| 39 |
+
# ββ Required schema fields ββββββββββββββββββββββββββββββββββββββββββββ
|
| 40 |
+
|
| 41 |
+
REQUIRED_FIELDS = {"id", "type", "source", "messages", "metadata"}
|
| 42 |
+
REQUIRED_METADATA = {"language", "tokens"}
|
| 43 |
+
VALID_ROLES = {"system", "user", "assistant"}
|
| 44 |
+
|
| 45 |
+
# ββ Protected sources (hand-crafted gold data β lighter filtering) βββββ
|
| 46 |
+
|
| 47 |
+
PROTECTED_SOURCES = {"sandbox_examples", "search_examples", "synthetic_nextjs"}
|
| 48 |
+
|
| 49 |
+
# ββ MINDI agentic token scoring bonuses βββββββββββββββββββββββββββββββ
|
| 50 |
+
# Examples with these tokens teach the model to be an *agent*.
|
| 51 |
+
# Each occurrence adds to the quality_score before the threshold.
|
| 52 |
+
|
| 53 |
+
MINDI_TOKEN_BONUSES = {
|
| 54 |
+
"<|think_start|>": 2.0,
|
| 55 |
+
"<|search_start|>": 3.0,
|
| 56 |
+
"<|error_start|>": 3.0,
|
| 57 |
+
"<|sandbox_start|>": 3.0,
|
| 58 |
+
"<|critique_start|>": 2.0,
|
| 59 |
+
"<|suggest_start|>": 1.0,
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
# ββ Special token pairs that assistant messages should contain βββββββββ
|
| 63 |
+
|
| 64 |
+
CODE_TOKEN_PAIRS = [
|
| 65 |
+
("<|code_start|>", "<|code_end|>"),
|
| 66 |
+
]
|
| 67 |
+
|
| 68 |
+
# At least one of these pairs should be present in assistant content
|
| 69 |
+
OPTIONAL_TOKEN_PAIRS = [
|
| 70 |
+
("<|think_start|>", "<|think_end|>"),
|
| 71 |
+
("<|critique_start|>", "<|critique_end|>"),
|
| 72 |
+
("<|suggest_start|>", "<|suggest_end|>"),
|
| 73 |
+
("<|file_start|>", "<|file_end|>"),
|
| 74 |
+
("<|search_start|>", "<|search_end|>"),
|
| 75 |
+
("<|sandbox_start|>", "<|sandbox_end|>"),
|
| 76 |
+
("<|error_start|>", "<|error_end|>"),
|
| 77 |
+
("<|fix_start|>", "<|fix_end|>"),
|
| 78 |
+
]
|
| 79 |
+
|
| 80 |
+
# ββ Rejection reasons βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 81 |
+
|
| 82 |
+
class Reason:
|
| 83 |
+
INVALID_JSON = "invalid_json"
|
| 84 |
+
MISSING_FIELDS = "missing_fields"
|
| 85 |
+
MISSING_METADATA = "missing_metadata"
|
| 86 |
+
NO_MESSAGES = "no_messages"
|
| 87 |
+
BAD_ROLES = "bad_message_roles"
|
| 88 |
+
NO_ASSISTANT = "no_assistant_message"
|
| 89 |
+
EMPTY_ASSISTANT = "empty_assistant_content"
|
| 90 |
+
TOO_SHORT = "too_few_tokens"
|
| 91 |
+
TOO_LONG = "too_many_tokens"
|
| 92 |
+
DUPLICATE = "duplicate_content"
|
| 93 |
+
LOW_QUALITY = "low_quality_score"
|
| 94 |
+
NO_CODE_TOKENS = "missing_code_tokens"
|
| 95 |
+
BOILERPLATE = "boilerplate_content"
|
| 96 |
+
UNMATCHED_TOKENS = "unmatched_special_tokens"
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# ββ Filter functions ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 100 |
+
|
| 101 |
+
def validate_schema(example: dict) -> str | None:
|
| 102 |
+
"""Check required fields and structure. Returns rejection reason or None."""
|
| 103 |
+
# Top-level fields
|
| 104 |
+
missing = REQUIRED_FIELDS - set(example.keys())
|
| 105 |
+
if missing:
|
| 106 |
+
return Reason.MISSING_FIELDS
|
| 107 |
+
|
| 108 |
+
# Metadata fields
|
| 109 |
+
meta = example.get("metadata", {})
|
| 110 |
+
if not isinstance(meta, dict):
|
| 111 |
+
return Reason.MISSING_METADATA
|
| 112 |
+
missing_meta = REQUIRED_METADATA - set(meta.keys())
|
| 113 |
+
if missing_meta:
|
| 114 |
+
return Reason.MISSING_METADATA
|
| 115 |
+
|
| 116 |
+
# Messages array
|
| 117 |
+
messages = example.get("messages", [])
|
| 118 |
+
if not isinstance(messages, list) or len(messages) == 0:
|
| 119 |
+
return Reason.NO_MESSAGES
|
| 120 |
+
|
| 121 |
+
# Role validation
|
| 122 |
+
for msg in messages:
|
| 123 |
+
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
|
| 124 |
+
return Reason.BAD_ROLES
|
| 125 |
+
if msg["role"] not in VALID_ROLES:
|
| 126 |
+
return Reason.BAD_ROLES
|
| 127 |
+
|
| 128 |
+
return None
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def get_assistant_content(example: dict) -> str:
|
| 132 |
+
"""Extract concatenated assistant message content."""
|
| 133 |
+
parts = []
|
| 134 |
+
for msg in example.get("messages", []):
|
| 135 |
+
if msg.get("role") == "assistant":
|
| 136 |
+
parts.append(msg.get("content", ""))
|
| 137 |
+
return "\n".join(parts)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def check_assistant_exists(example: dict) -> str | None:
|
| 141 |
+
"""Must have at least one assistant message with non-empty content."""
|
| 142 |
+
content = get_assistant_content(example)
|
| 143 |
+
if not content:
|
| 144 |
+
return Reason.NO_ASSISTANT
|
| 145 |
+
if len(content.strip()) < 10:
|
| 146 |
+
return Reason.EMPTY_ASSISTANT
|
| 147 |
+
return None
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def check_token_length(example: dict, min_tokens: int, max_tokens: int) -> str | None:
|
| 151 |
+
"""Filter by token count stored in metadata."""
|
| 152 |
+
tokens = example.get("metadata", {}).get("tokens", 0)
|
| 153 |
+
if tokens < min_tokens:
|
| 154 |
+
return Reason.TOO_SHORT
|
| 155 |
+
if tokens > max_tokens:
|
| 156 |
+
return Reason.TOO_LONG
|
| 157 |
+
return None
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def compute_mindi_bonus(example: dict) -> float:
|
| 161 |
+
"""Compute bonus score for MINDI agentic special tokens."""
|
| 162 |
+
content = get_assistant_content(example)
|
| 163 |
+
bonus = 0.0
|
| 164 |
+
for token, value in MINDI_TOKEN_BONUSES.items():
|
| 165 |
+
if token in content:
|
| 166 |
+
bonus += value
|
| 167 |
+
return bonus
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def check_quality_score(example: dict, min_quality: float) -> str | None:
|
| 171 |
+
"""Filter by quality_score + MINDI token bonus."""
|
| 172 |
+
score = example.get("metadata", {}).get("quality_score", 0.0)
|
| 173 |
+
score += compute_mindi_bonus(example)
|
| 174 |
+
if score < min_quality:
|
| 175 |
+
return Reason.LOW_QUALITY
|
| 176 |
+
return None
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def check_code_tokens(example: dict) -> str | None:
|
| 180 |
+
"""Assistant content must contain code_start/code_end pair."""
|
| 181 |
+
content = get_assistant_content(example)
|
| 182 |
+
|
| 183 |
+
for start_tok, end_tok in CODE_TOKEN_PAIRS:
|
| 184 |
+
if start_tok in content and end_tok in content:
|
| 185 |
+
# Check ordering: start before end
|
| 186 |
+
if content.index(start_tok) < content.rindex(end_tok):
|
| 187 |
+
return None # OK
|
| 188 |
+
|
| 189 |
+
return Reason.NO_CODE_TOKENS
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def check_unmatched_tokens(example: dict) -> str | None:
|
| 193 |
+
"""Ensure all special token pairs are properly matched (start count == end count)."""
|
| 194 |
+
content = get_assistant_content(example)
|
| 195 |
+
all_pairs = CODE_TOKEN_PAIRS + OPTIONAL_TOKEN_PAIRS
|
| 196 |
+
|
| 197 |
+
for start_tok, end_tok in all_pairs:
|
| 198 |
+
start_count = content.count(start_tok)
|
| 199 |
+
end_count = content.count(end_tok)
|
| 200 |
+
if start_count != end_count:
|
| 201 |
+
return Reason.UNMATCHED_TOKENS
|
| 202 |
+
|
| 203 |
+
return None
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def check_boilerplate(example: dict) -> str | None:
|
| 207 |
+
"""Detect boilerplate/placeholder assistant responses."""
|
| 208 |
+
content = get_assistant_content(example)
|
| 209 |
+
content_lower = content.lower().strip()
|
| 210 |
+
|
| 211 |
+
# Very short code blocks (just placeholder)
|
| 212 |
+
code_markers = ("<|code_start|>", "<|code_end|>")
|
| 213 |
+
if code_markers[0] in content and code_markers[1] in content:
|
| 214 |
+
start_idx = content.index(code_markers[0]) + len(code_markers[0])
|
| 215 |
+
end_idx = content.index(code_markers[1])
|
| 216 |
+
code_body = content[start_idx:end_idx].strip()
|
| 217 |
+
if len(code_body) < 5:
|
| 218 |
+
return Reason.BOILERPLATE
|
| 219 |
+
|
| 220 |
+
# Repetitive content (same char repeated)
|
| 221 |
+
stripped = content_lower.replace(" ", "").replace("\n", "")
|
| 222 |
+
if len(stripped) > 20:
|
| 223 |
+
unique_chars = len(set(stripped))
|
| 224 |
+
if unique_chars < 5:
|
| 225 |
+
return Reason.BOILERPLATE
|
| 226 |
+
|
| 227 |
+
return None
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def content_hash(example: dict) -> str:
|
| 231 |
+
"""SHA-256 hash of assistant content for deduplication."""
|
| 232 |
+
content = get_assistant_content(example)
|
| 233 |
+
return hashlib.sha256(content.encode("utf-8", errors="replace")).hexdigest()
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# ββ Main pipeline βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 237 |
+
|
| 238 |
+
def run_filter(
|
| 239 |
+
dry_run: bool = False,
|
| 240 |
+
min_tokens: int = 50,
|
| 241 |
+
max_tokens: int = 4096,
|
| 242 |
+
min_quality: float = 5.0,
|
| 243 |
+
) -> None:
|
| 244 |
+
"""Run the full quality filter pipeline."""
|
| 245 |
+
|
| 246 |
+
if not INPUT_FILE.exists():
|
| 247 |
+
print(f"ERROR: Input file not found: {INPUT_FILE}")
|
| 248 |
+
sys.exit(1)
|
| 249 |
+
|
| 250 |
+
# Count input lines
|
| 251 |
+
print(f"Counting input examples from {INPUT_FILE.name} ...")
|
| 252 |
+
total_input = sum(1 for _ in open(INPUT_FILE, "r", encoding="utf-8"))
|
| 253 |
+
print(f" Total input: {total_input:,} examples")
|
| 254 |
+
print()
|
| 255 |
+
|
| 256 |
+
# Filter settings
|
| 257 |
+
print("Filter settings:")
|
| 258 |
+
print(f" Min tokens: {min_tokens}")
|
| 259 |
+
print(f" Max tokens: {max_tokens}")
|
| 260 |
+
print(f" Min quality: {min_quality}")
|
| 261 |
+
print(f" Dry run: {dry_run}")
|
| 262 |
+
print()
|
| 263 |
+
|
| 264 |
+
# Stats tracking
|
| 265 |
+
kept = 0
|
| 266 |
+
rejected = 0
|
| 267 |
+
reject_reasons: Counter = Counter()
|
| 268 |
+
source_kept: Counter = Counter()
|
| 269 |
+
source_rejected: Counter = Counter()
|
| 270 |
+
seen_hashes: set[str] = set()
|
| 271 |
+
token_sum = 0
|
| 272 |
+
quality_sum = 0.0
|
| 273 |
+
|
| 274 |
+
# Type distribution
|
| 275 |
+
type_counts: Counter = Counter()
|
| 276 |
+
|
| 277 |
+
# Language distribution
|
| 278 |
+
lang_counts: Counter = Counter()
|
| 279 |
+
|
| 280 |
+
start_time = time.time()
|
| 281 |
+
|
| 282 |
+
out_f = None
|
| 283 |
+
rej_f = None
|
| 284 |
+
if not dry_run:
|
| 285 |
+
OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)
|
| 286 |
+
out_f = open(OUTPUT_FILE, "w", encoding="utf-8")
|
| 287 |
+
rej_f = open(REJECT_FILE, "w", encoding="utf-8")
|
| 288 |
+
|
| 289 |
+
try:
|
| 290 |
+
with open(INPUT_FILE, "r", encoding="utf-8") as f:
|
| 291 |
+
for line_num, line in enumerate(f, 1):
|
| 292 |
+
line = line.strip()
|
| 293 |
+
if not line:
|
| 294 |
+
continue
|
| 295 |
+
|
| 296 |
+
# Parse JSON
|
| 297 |
+
try:
|
| 298 |
+
example = json.loads(line)
|
| 299 |
+
except json.JSONDecodeError:
|
| 300 |
+
reject_reasons[Reason.INVALID_JSON] += 1
|
| 301 |
+
rejected += 1
|
| 302 |
+
if rej_f:
|
| 303 |
+
rej_f.write(line + "\n")
|
| 304 |
+
continue
|
| 305 |
+
|
| 306 |
+
source = example.get("source", "unknown")
|
| 307 |
+
is_protected = source in PROTECTED_SOURCES
|
| 308 |
+
|
| 309 |
+
# Run filter chain (order matters: cheapest first)
|
| 310 |
+
# Protected sources: schema + assistant + token length + unmatched only
|
| 311 |
+
# Regular sources: full chain + dedup
|
| 312 |
+
if is_protected:
|
| 313 |
+
rejection = (
|
| 314 |
+
validate_schema(example)
|
| 315 |
+
or check_assistant_exists(example)
|
| 316 |
+
or check_token_length(example, min_tokens, max_tokens)
|
| 317 |
+
or check_unmatched_tokens(example)
|
| 318 |
+
)
|
| 319 |
+
else:
|
| 320 |
+
rejection = (
|
| 321 |
+
validate_schema(example)
|
| 322 |
+
or check_assistant_exists(example)
|
| 323 |
+
or check_token_length(example, min_tokens, max_tokens)
|
| 324 |
+
or check_quality_score(example, min_quality)
|
| 325 |
+
or check_code_tokens(example)
|
| 326 |
+
or check_unmatched_tokens(example)
|
| 327 |
+
or check_boilerplate(example)
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
if rejection is None and not is_protected:
|
| 331 |
+
# Dedup check (skip for protected sources)
|
| 332 |
+
h = content_hash(example)
|
| 333 |
+
if h in seen_hashes:
|
| 334 |
+
rejection = Reason.DUPLICATE
|
| 335 |
+
|
| 336 |
+
if rejection is not None:
|
| 337 |
+
reject_reasons[rejection] += 1
|
| 338 |
+
source_rejected[source] += 1
|
| 339 |
+
rejected += 1
|
| 340 |
+
if rej_f:
|
| 341 |
+
rej_f.write(line + "\n")
|
| 342 |
+
continue
|
| 343 |
+
|
| 344 |
+
# Passed all filters
|
| 345 |
+
if not is_protected:
|
| 346 |
+
seen_hashes.add(h)
|
| 347 |
+
kept += 1
|
| 348 |
+
source_kept[source] += 1
|
| 349 |
+
token_sum += example.get("metadata", {}).get("tokens", 0)
|
| 350 |
+
quality_sum += example.get("metadata", {}).get("quality_score", 0.0)
|
| 351 |
+
type_counts[example.get("type", "unknown")] += 1
|
| 352 |
+
lang_counts[example.get("metadata", {}).get("language", "unknown")] += 1
|
| 353 |
+
|
| 354 |
+
if out_f:
|
| 355 |
+
out_f.write(line + "\n")
|
| 356 |
+
|
| 357 |
+
# Progress
|
| 358 |
+
if line_num % 50000 == 0:
|
| 359 |
+
elapsed = time.time() - start_time
|
| 360 |
+
rate = line_num / elapsed if elapsed > 0 else 0
|
| 361 |
+
pct = (line_num / total_input) * 100
|
| 362 |
+
print(f" [{pct:5.1f}%] Processed {line_num:>10,} | Kept {kept:>10,} | Rejected {rejected:>10,} | {rate:,.0f} ex/s")
|
| 363 |
+
|
| 364 |
+
finally:
|
| 365 |
+
if out_f:
|
| 366 |
+
out_f.close()
|
| 367 |
+
if rej_f:
|
| 368 |
+
rej_f.close()
|
| 369 |
+
|
| 370 |
+
elapsed = time.time() - start_time
|
| 371 |
+
|
| 372 |
+
# ββ Summary report ββββββββββββββββββββββββββββββββββββββββββββ
|
| 373 |
+
print()
|
| 374 |
+
print("=" * 60)
|
| 375 |
+
print(" QUALITY FILTER REPORT")
|
| 376 |
+
print("=" * 60)
|
| 377 |
+
print(f" Input: {total_input:>10,} examples")
|
| 378 |
+
print(f" Kept: {kept:>10,} examples ({kept/total_input*100:.1f}%)")
|
| 379 |
+
print(f" Rejected: {rejected:>10,} examples ({rejected/total_input*100:.1f}%)")
|
| 380 |
+
print(f" Time: {elapsed:>10.1f} seconds")
|
| 381 |
+
print(f" Rate: {total_input/elapsed:>10,.0f} examples/sec")
|
| 382 |
+
print()
|
| 383 |
+
|
| 384 |
+
if kept > 0:
|
| 385 |
+
print(f" Avg tokens: {token_sum/kept:>10.0f}")
|
| 386 |
+
print(f" Avg quality: {quality_sum/kept:>10.2f}")
|
| 387 |
+
print(f" Total tokens:{token_sum:>10,}")
|
| 388 |
+
print()
|
| 389 |
+
|
| 390 |
+
# Rejection breakdown
|
| 391 |
+
print(" Rejection breakdown:")
|
| 392 |
+
for reason, count in reject_reasons.most_common():
|
| 393 |
+
pct = count / total_input * 100
|
| 394 |
+
print(f" {reason:<30s} {count:>10,} ({pct:.1f}%)")
|
| 395 |
+
print()
|
| 396 |
+
|
| 397 |
+
# Source breakdown
|
| 398 |
+
print(" Source breakdown (kept / total):")
|
| 399 |
+
all_sources = sorted(set(list(source_kept.keys()) + list(source_rejected.keys())))
|
| 400 |
+
for src in all_sources:
|
| 401 |
+
k = source_kept.get(src, 0)
|
| 402 |
+
total = k + source_rejected.get(src, 0)
|
| 403 |
+
pct = k / total * 100 if total > 0 else 0
|
| 404 |
+
print(f" {src:<25s} {k:>8,} / {total:>8,} ({pct:.1f}%)")
|
| 405 |
+
print()
|
| 406 |
+
|
| 407 |
+
# Type distribution
|
| 408 |
+
print(" Type distribution (kept):")
|
| 409 |
+
for t, c in type_counts.most_common(10):
|
| 410 |
+
print(f" {t:<25s} {c:>8,}")
|
| 411 |
+
print()
|
| 412 |
+
|
| 413 |
+
# Language distribution (top 15)
|
| 414 |
+
print(" Language distribution (kept, top 15):")
|
| 415 |
+
for lang, c in lang_counts.most_common(15):
|
| 416 |
+
print(f" {lang:<25s} {c:>8,}")
|
| 417 |
+
print()
|
| 418 |
+
|
| 419 |
+
if not dry_run:
|
| 420 |
+
print(f" Output: {OUTPUT_FILE}")
|
| 421 |
+
print(f" Rejects: {REJECT_FILE}")
|
| 422 |
+
|
| 423 |
+
# Save machine-readable report
|
| 424 |
+
report = {
|
| 425 |
+
"input_count": total_input,
|
| 426 |
+
"kept_count": kept,
|
| 427 |
+
"rejected_count": rejected,
|
| 428 |
+
"kept_pct": round(kept / total_input * 100, 2),
|
| 429 |
+
"avg_tokens": round(token_sum / kept, 1) if kept > 0 else 0,
|
| 430 |
+
"avg_quality": round(quality_sum / kept, 3) if kept > 0 else 0,
|
| 431 |
+
"total_tokens": token_sum,
|
| 432 |
+
"elapsed_seconds": round(elapsed, 1),
|
| 433 |
+
"filter_settings": {
|
| 434 |
+
"min_tokens": min_tokens,
|
| 435 |
+
"max_tokens": max_tokens,
|
| 436 |
+
"min_quality": min_quality,
|
| 437 |
+
},
|
| 438 |
+
"rejection_breakdown": dict(reject_reasons.most_common()),
|
| 439 |
+
"source_kept": dict(source_kept),
|
| 440 |
+
"source_rejected": dict(source_rejected),
|
| 441 |
+
"type_distribution": dict(type_counts.most_common()),
|
| 442 |
+
"language_distribution": dict(lang_counts.most_common(30)),
|
| 443 |
+
}
|
| 444 |
+
with open(REPORT_FILE, "w", encoding="utf-8") as rf:
|
| 445 |
+
json.dump(report, rf, indent=2)
|
| 446 |
+
print(f" Report: {REPORT_FILE}")
|
| 447 |
+
|
| 448 |
+
print("=" * 60)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
# ββ CLI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 452 |
+
|
| 453 |
+
def main():
|
| 454 |
+
parser = argparse.ArgumentParser(
|
| 455 |
+
description="MINDI Quality Filter β remove low-quality training examples",
|
| 456 |
+
)
|
| 457 |
+
parser.add_argument("--dry-run", action="store_true", help="Preview counts without writing output")
|
| 458 |
+
parser.add_argument("--min-tokens", type=int, default=50, help="Minimum token count (default: 50)")
|
| 459 |
+
parser.add_argument("--max-tokens", type=int, default=4096, help="Maximum token count (default: 4096)")
|
| 460 |
+
parser.add_argument("--min-quality", type=float, default=5.0, help="Minimum quality_score (default: 5.0)")
|
| 461 |
+
|
| 462 |
+
args = parser.parse_args()
|
| 463 |
+
run_filter(
|
| 464 |
+
dry_run=args.dry_run,
|
| 465 |
+
min_tokens=args.min_tokens,
|
| 466 |
+
max_tokens=args.max_tokens,
|
| 467 |
+
min_quality=args.min_quality,
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
if __name__ == "__main__":
|
| 472 |
+
main()
|
scripts/split_data.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
MINDI 1.5 Vision-Coder β Train / Validation / Test Split
|
| 4 |
+
|
| 5 |
+
Splits mindi_filtered.jsonl into:
|
| 6 |
+
- train.jsonl (90%)
|
| 7 |
+
- val.jsonl (5%)
|
| 8 |
+
- test.jsonl (5%)
|
| 9 |
+
|
| 10 |
+
Stratified by source to ensure proportional representation.
|
| 11 |
+
Deterministic with a fixed random seed.
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
python scripts/split_data.py # Default 90/5/5
|
| 15 |
+
python scripts/split_data.py --train 0.85 --val 0.10 --test 0.05
|
| 16 |
+
python scripts/split_data.py --seed 42
|
| 17 |
+
python scripts/split_data.py --dry-run
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import json
|
| 24 |
+
import random
|
| 25 |
+
import sys
|
| 26 |
+
import time
|
| 27 |
+
from collections import Counter
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
|
| 30 |
+
# ββ Paths βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 31 |
+
|
| 32 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 33 |
+
INPUT_FILE = PROJECT_ROOT / "data" / "processed" / "mindi_filtered.jsonl"
|
| 34 |
+
OUTPUT_DIR = PROJECT_ROOT / "data" / "processed"
|
| 35 |
+
TRAIN_FILE = OUTPUT_DIR / "train.jsonl"
|
| 36 |
+
VAL_FILE = OUTPUT_DIR / "val.jsonl"
|
| 37 |
+
TEST_FILE = OUTPUT_DIR / "test.jsonl"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def run_split(
|
| 41 |
+
train_ratio: float = 0.90,
|
| 42 |
+
val_ratio: float = 0.05,
|
| 43 |
+
test_ratio: float = 0.05,
|
| 44 |
+
seed: int = 42,
|
| 45 |
+
dry_run: bool = False,
|
| 46 |
+
) -> None:
|
| 47 |
+
"""Split filtered data into train/val/test with stratification by source."""
|
| 48 |
+
|
| 49 |
+
# Validate ratios
|
| 50 |
+
total_ratio = train_ratio + val_ratio + test_ratio
|
| 51 |
+
if abs(total_ratio - 1.0) > 0.001:
|
| 52 |
+
print(f"ERROR: Ratios must sum to 1.0, got {total_ratio:.3f}")
|
| 53 |
+
sys.exit(1)
|
| 54 |
+
|
| 55 |
+
if not INPUT_FILE.exists():
|
| 56 |
+
print(f"ERROR: Input file not found: {INPUT_FILE}")
|
| 57 |
+
print(" Run quality_filter.py first to generate mindi_filtered.jsonl")
|
| 58 |
+
sys.exit(1)
|
| 59 |
+
|
| 60 |
+
print(f"Loading examples from {INPUT_FILE.name} ...")
|
| 61 |
+
start = time.time()
|
| 62 |
+
|
| 63 |
+
# Group lines by source for stratified splitting
|
| 64 |
+
source_lines: dict[str, list[str]] = {}
|
| 65 |
+
total = 0
|
| 66 |
+
with open(INPUT_FILE, "r", encoding="utf-8") as f:
|
| 67 |
+
for line in f:
|
| 68 |
+
line = line.strip()
|
| 69 |
+
if not line:
|
| 70 |
+
continue
|
| 71 |
+
total += 1
|
| 72 |
+
try:
|
| 73 |
+
example = json.loads(line)
|
| 74 |
+
source = example.get("source", "unknown")
|
| 75 |
+
except json.JSONDecodeError:
|
| 76 |
+
source = "unknown"
|
| 77 |
+
source_lines.setdefault(source, []).append(line)
|
| 78 |
+
|
| 79 |
+
load_time = time.time() - start
|
| 80 |
+
print(f" Loaded {total:,} examples in {load_time:.1f}s")
|
| 81 |
+
print(f" Sources: {len(source_lines)}")
|
| 82 |
+
print()
|
| 83 |
+
|
| 84 |
+
# Split settings
|
| 85 |
+
print(f"Split ratios: train={train_ratio:.0%} / val={val_ratio:.0%} / test={test_ratio:.0%}")
|
| 86 |
+
print(f"Random seed: {seed}")
|
| 87 |
+
print(f"Dry run: {dry_run}")
|
| 88 |
+
print()
|
| 89 |
+
|
| 90 |
+
rng = random.Random(seed)
|
| 91 |
+
|
| 92 |
+
train_lines: list[str] = []
|
| 93 |
+
val_lines: list[str] = []
|
| 94 |
+
test_lines: list[str] = []
|
| 95 |
+
|
| 96 |
+
source_stats: dict[str, dict[str, int]] = {}
|
| 97 |
+
|
| 98 |
+
for source in sorted(source_lines.keys()):
|
| 99 |
+
lines = source_lines[source]
|
| 100 |
+
rng.shuffle(lines)
|
| 101 |
+
|
| 102 |
+
n = len(lines)
|
| 103 |
+
n_val = max(1, round(n * val_ratio)) if n >= 3 else 0
|
| 104 |
+
n_test = max(1, round(n * test_ratio)) if n >= 3 else 0
|
| 105 |
+
n_train = n - n_val - n_test
|
| 106 |
+
|
| 107 |
+
# Edge case: if too few examples, put all in train
|
| 108 |
+
if n < 3:
|
| 109 |
+
n_train = n
|
| 110 |
+
n_val = 0
|
| 111 |
+
n_test = 0
|
| 112 |
+
|
| 113 |
+
train_lines.extend(lines[:n_train])
|
| 114 |
+
val_lines.extend(lines[n_train:n_train + n_val])
|
| 115 |
+
test_lines.extend(lines[n_train + n_val:])
|
| 116 |
+
|
| 117 |
+
source_stats[source] = {
|
| 118 |
+
"total": n,
|
| 119 |
+
"train": n_train,
|
| 120 |
+
"val": n_val,
|
| 121 |
+
"test": n_test,
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
# Shuffle final lists (so sources are interleaved)
|
| 125 |
+
rng.shuffle(train_lines)
|
| 126 |
+
rng.shuffle(val_lines)
|
| 127 |
+
rng.shuffle(test_lines)
|
| 128 |
+
|
| 129 |
+
# ββ Summary βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 130 |
+
print("=" * 60)
|
| 131 |
+
print(" SPLIT SUMMARY")
|
| 132 |
+
print("=" * 60)
|
| 133 |
+
print(f" Total: {total:>10,}")
|
| 134 |
+
print(f" Train: {len(train_lines):>10,} ({len(train_lines)/total*100:.1f}%)")
|
| 135 |
+
print(f" Validation: {len(val_lines):>10,} ({len(val_lines)/total*100:.1f}%)")
|
| 136 |
+
print(f" Test: {len(test_lines):>10,} ({len(test_lines)/total*100:.1f}%)")
|
| 137 |
+
print()
|
| 138 |
+
|
| 139 |
+
print(" Per-source breakdown:")
|
| 140 |
+
print(f" {'Source':<25s} {'Total':>8s} {'Train':>8s} {'Val':>8s} {'Test':>8s}")
|
| 141 |
+
print(f" {'-'*25} {'-'*8} {'-'*8} {'-'*8} {'-'*8}")
|
| 142 |
+
for source in sorted(source_stats.keys()):
|
| 143 |
+
s = source_stats[source]
|
| 144 |
+
print(f" {source:<25s} {s['total']:>8,} {s['train']:>8,} {s['val']:>8,} {s['test']:>8,}")
|
| 145 |
+
print()
|
| 146 |
+
|
| 147 |
+
if not dry_run:
|
| 148 |
+
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 149 |
+
|
| 150 |
+
print("Writing files ...")
|
| 151 |
+
for path, lines, name in [
|
| 152 |
+
(TRAIN_FILE, train_lines, "train"),
|
| 153 |
+
(VAL_FILE, val_lines, "val"),
|
| 154 |
+
(TEST_FILE, test_lines, "test"),
|
| 155 |
+
]:
|
| 156 |
+
with open(path, "w", encoding="utf-8") as f:
|
| 157 |
+
for line in lines:
|
| 158 |
+
f.write(line + "\n")
|
| 159 |
+
size_mb = path.stat().st_size / (1024 * 1024)
|
| 160 |
+
print(f" {name:<12s} β {path.name:<20s} ({len(lines):>10,} examples, {size_mb:>8.1f} MB)")
|
| 161 |
+
|
| 162 |
+
# Save split metadata
|
| 163 |
+
meta = {
|
| 164 |
+
"total": total,
|
| 165 |
+
"train_count": len(train_lines),
|
| 166 |
+
"val_count": len(val_lines),
|
| 167 |
+
"test_count": len(test_lines),
|
| 168 |
+
"train_pct": round(len(train_lines) / total * 100, 2),
|
| 169 |
+
"val_pct": round(len(val_lines) / total * 100, 2),
|
| 170 |
+
"test_pct": round(len(test_lines) / total * 100, 2),
|
| 171 |
+
"seed": seed,
|
| 172 |
+
"source_breakdown": source_stats,
|
| 173 |
+
}
|
| 174 |
+
meta_path = OUTPUT_DIR / "split_meta.json"
|
| 175 |
+
with open(meta_path, "w", encoding="utf-8") as f:
|
| 176 |
+
json.dump(meta, f, indent=2)
|
| 177 |
+
print(f" Metadata β {meta_path.name}")
|
| 178 |
+
|
| 179 |
+
elapsed = time.time() - start
|
| 180 |
+
print(f"\n Done in {elapsed:.1f}s")
|
| 181 |
+
print("=" * 60)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# ββ CLI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 185 |
+
|
| 186 |
+
def main():
|
| 187 |
+
parser = argparse.ArgumentParser(
|
| 188 |
+
description="MINDI Data Splitter β stratified train/val/test split",
|
| 189 |
+
)
|
| 190 |
+
parser.add_argument("--train", type=float, default=0.90, help="Train ratio (default: 0.90)")
|
| 191 |
+
parser.add_argument("--val", type=float, default=0.05, help="Validation ratio (default: 0.05)")
|
| 192 |
+
parser.add_argument("--test", type=float, default=0.05, help="Test ratio (default: 0.05)")
|
| 193 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
|
| 194 |
+
parser.add_argument("--dry-run", action="store_true", help="Preview split without writing files")
|
| 195 |
+
|
| 196 |
+
args = parser.parse_args()
|
| 197 |
+
run_split(
|
| 198 |
+
train_ratio=args.train,
|
| 199 |
+
val_ratio=args.val,
|
| 200 |
+
test_ratio=args.test,
|
| 201 |
+
seed=args.seed,
|
| 202 |
+
dry_run=args.dry_run,
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
if __name__ == "__main__":
|
| 207 |
+
main()
|
scripts/train.py
CHANGED
|
@@ -1,39 +1,348 @@
|
|
|
|
|
| 1 |
"""
|
| 2 |
-
MINDI 1.5 Vision-Coder β Training
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
from __future__ import annotations
|
| 9 |
|
| 10 |
import argparse
|
|
|
|
|
|
|
|
|
|
| 11 |
from pathlib import Path
|
| 12 |
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
parser.add_argument(
|
| 18 |
-
"--config", type=str,
|
|
|
|
| 19 |
help="Path to training config YAML",
|
| 20 |
)
|
| 21 |
parser.add_argument(
|
| 22 |
-
"--
|
| 23 |
-
help="
|
| 24 |
)
|
| 25 |
parser.add_argument(
|
| 26 |
-
"--
|
| 27 |
-
help="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
)
|
| 29 |
-
args = parser.parse_args()
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
if __name__ == "__main__":
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
MINDI 1.5 Vision-Coder β Master Training Script
|
| 4 |
|
| 5 |
+
Usage:
|
| 6 |
+
python scripts/train.py --phase 1 # Run phase 1 only
|
| 7 |
+
python scripts/train.py --phase all # Run all 3 phases
|
| 8 |
+
python scripts/train.py --phase 2 --resume checkpoints/training/phase1_lora_step5000
|
| 9 |
+
python scripts/train.py --dry_run # Test 10 steps only
|
| 10 |
+
python scripts/train.py --push_to_hub # Upload after training
|
| 11 |
+
|
| 12 |
+
Handles Ctrl+C gracefully: saves checkpoint before exit.
|
| 13 |
"""
|
| 14 |
|
| 15 |
from __future__ import annotations
|
| 16 |
|
| 17 |
import argparse
|
| 18 |
+
import signal
|
| 19 |
+
import sys
|
| 20 |
+
import traceback
|
| 21 |
from pathlib import Path
|
| 22 |
|
| 23 |
+
# Resolve project root (scripts/ is one level deep)
|
| 24 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 25 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 26 |
|
| 27 |
+
import torch
|
| 28 |
+
import yaml
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def parse_args() -> argparse.Namespace:
|
| 32 |
+
parser = argparse.ArgumentParser(
|
| 33 |
+
description="MINDI 1.5 Vision-Coder β Training",
|
| 34 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 35 |
+
)
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--phase", type=str, default="all",
|
| 38 |
+
choices=["1", "2", "3", "all"],
|
| 39 |
+
help="Which phase(s) to run: 1, 2, 3, or all (default: all)",
|
| 40 |
+
)
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
"--resume", type=str, default=None,
|
| 43 |
+
help="Path to checkpoint directory to resume from",
|
| 44 |
+
)
|
| 45 |
parser.add_argument(
|
| 46 |
+
"--config", type=str,
|
| 47 |
+
default=str(PROJECT_ROOT / "configs" / "training_config.yaml"),
|
| 48 |
help="Path to training config YAML",
|
| 49 |
)
|
| 50 |
parser.add_argument(
|
| 51 |
+
"--dry_run", action="store_true",
|
| 52 |
+
help="Test run: only 10 steps per phase",
|
| 53 |
)
|
| 54 |
parser.add_argument(
|
| 55 |
+
"--push_to_hub", action="store_true",
|
| 56 |
+
help="Push checkpoints to HuggingFace after each phase",
|
| 57 |
+
)
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--no_wandb", action="store_true",
|
| 60 |
+
help="Disable WandB logging",
|
| 61 |
+
)
|
| 62 |
+
return parser.parse_args()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def load_config(config_path: str) -> dict:
|
| 66 |
+
"""Load and return the training config YAML."""
|
| 67 |
+
path = Path(config_path)
|
| 68 |
+
if not path.exists():
|
| 69 |
+
raise FileNotFoundError(f"Config not found: {path}")
|
| 70 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 71 |
+
return yaml.safe_load(f)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def build_training_config(raw: dict, dry_run: bool = False):
|
| 75 |
+
"""Build TrainingConfig from parsed YAML."""
|
| 76 |
+
from src.training.mindi_trainer import PhaseConfig, TrainingConfig
|
| 77 |
+
|
| 78 |
+
training = raw.get("training", {})
|
| 79 |
+
data = raw.get("data", {})
|
| 80 |
+
output = raw.get("output", {})
|
| 81 |
+
logging_cfg = raw.get("logging", {})
|
| 82 |
+
model_cfg = raw.get("model", {})
|
| 83 |
+
|
| 84 |
+
# Build phase configs from YAML
|
| 85 |
+
phases = []
|
| 86 |
+
phase_defs = [
|
| 87 |
+
("phase1", "phase1_lora", True, False, False),
|
| 88 |
+
("phase2", "phase2_vision_bridge", False, True, True),
|
| 89 |
+
("phase3", "phase3_all", True, True, True),
|
| 90 |
+
]
|
| 91 |
+
cumulative_step = 0
|
| 92 |
+
for key, name, lora, vision, fusion in phase_defs:
|
| 93 |
+
pcfg = training.get(key, {})
|
| 94 |
+
steps = pcfg.get("steps", 2500)
|
| 95 |
+
if dry_run:
|
| 96 |
+
steps = 10
|
| 97 |
+
start = cumulative_step
|
| 98 |
+
end = cumulative_step + steps
|
| 99 |
+
phases.append(PhaseConfig(
|
| 100 |
+
name=name,
|
| 101 |
+
start_step=start,
|
| 102 |
+
end_step=end,
|
| 103 |
+
learning_rate=float(pcfg.get("lr", 2e-4)),
|
| 104 |
+
batch_size=pcfg.get("batch_size", 8),
|
| 105 |
+
gradient_accumulation_steps=training.get("grad_accumulation", 4),
|
| 106 |
+
lora=lora,
|
| 107 |
+
vision_projection=vision,
|
| 108 |
+
fusion=fusion,
|
| 109 |
+
))
|
| 110 |
+
cumulative_step = end
|
| 111 |
+
|
| 112 |
+
config = TrainingConfig(
|
| 113 |
+
train_file=PROJECT_ROOT / data.get("train_file", "data/processed/train.jsonl"),
|
| 114 |
+
val_file=PROJECT_ROOT / data.get("val_file", "data/processed/val.jsonl"),
|
| 115 |
+
output_dir=PROJECT_ROOT / output.get("checkpoint_dir", "checkpoints/training"),
|
| 116 |
+
log_dir=PROJECT_ROOT / logging_cfg.get("log_dir", "logs/training"),
|
| 117 |
+
max_seq_length=data.get("max_length", 4096),
|
| 118 |
+
use_compile=model_cfg.get("use_compile", False),
|
| 119 |
+
gradient_checkpointing=model_cfg.get("gradient_checkpointing", True),
|
| 120 |
+
dtype=model_cfg.get("dtype", "bf16"),
|
| 121 |
+
num_workers=data.get("num_workers", 4),
|
| 122 |
+
pin_memory=True,
|
| 123 |
+
prefetch_factor=2,
|
| 124 |
+
weight_decay=0.01,
|
| 125 |
+
warmup_ratio=0.03,
|
| 126 |
+
max_grad_norm=float(training.get("max_grad_norm", 1.0)),
|
| 127 |
+
seed=42,
|
| 128 |
+
log_every_n_steps=logging_cfg.get("log_every", 10),
|
| 129 |
+
eval_every_n_steps=training.get("eval_every", 250),
|
| 130 |
+
save_every_n_steps=training.get("save_every", 500),
|
| 131 |
+
phases=phases,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
if dry_run:
|
| 135 |
+
config.eval_every_n_steps = 5
|
| 136 |
+
config.save_every_n_steps = 10
|
| 137 |
+
config.log_every_n_steps = 1
|
| 138 |
+
|
| 139 |
+
return config
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def init_wandb(raw_config: dict, phase: str, disabled: bool = False):
|
| 143 |
+
"""Initialize WandB logging."""
|
| 144 |
+
if disabled:
|
| 145 |
+
return None
|
| 146 |
+
try:
|
| 147 |
+
import wandb
|
| 148 |
+
logging_cfg = raw_config.get("logging", {})
|
| 149 |
+
run = wandb.init(
|
| 150 |
+
project=logging_cfg.get("wandb_project", "mindi-1.5-vision-coder"),
|
| 151 |
+
entity=logging_cfg.get("wandb_entity", "mindigenous"),
|
| 152 |
+
name=f"mindi15-{phase}",
|
| 153 |
+
config=raw_config,
|
| 154 |
+
tags=["mindi-1.5", "training", f"phase-{phase}"],
|
| 155 |
+
reinit=True,
|
| 156 |
+
)
|
| 157 |
+
print(f"[train.py] WandB initialized: {run.url}")
|
| 158 |
+
return run
|
| 159 |
+
except ImportError:
|
| 160 |
+
print("[train.py] WandB not installed β logging disabled")
|
| 161 |
+
return None
|
| 162 |
+
except Exception as e:
|
| 163 |
+
print(f"[train.py] WandB init failed: {e} β continuing without logging")
|
| 164 |
+
return None
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def push_checkpoint_to_hub(checkpoint_dir: Path, raw_config: dict) -> None:
|
| 168 |
+
"""Push a checkpoint to HuggingFace Hub."""
|
| 169 |
+
output = raw_config.get("output", {})
|
| 170 |
+
repo_id = output.get("hf_repo", "Mindigenous/MINDI-1.5-Vision-Coder")
|
| 171 |
+
|
| 172 |
+
try:
|
| 173 |
+
from huggingface_hub import HfApi
|
| 174 |
+
import os
|
| 175 |
+
api = HfApi(token=os.environ.get("HF_TOKEN"))
|
| 176 |
+
|
| 177 |
+
print(f"[train.py] Pushing checkpoint to {repo_id} ...")
|
| 178 |
+
api.upload_folder(
|
| 179 |
+
folder_path=str(checkpoint_dir),
|
| 180 |
+
repo_id=repo_id,
|
| 181 |
+
path_in_repo=f"checkpoints/{checkpoint_dir.name}",
|
| 182 |
+
repo_type="model",
|
| 183 |
+
)
|
| 184 |
+
print(f"[train.py] Pushed to https://huggingface.co/{repo_id}")
|
| 185 |
+
except ImportError:
|
| 186 |
+
print("[train.py] huggingface_hub not installed β skipping push")
|
| 187 |
+
except Exception as e:
|
| 188 |
+
print(f"[train.py] Push to hub failed: {e}")
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def log_wandb_phase_complete(wandb_run, summary: dict) -> None:
|
| 192 |
+
"""Log phase completion to WandB."""
|
| 193 |
+
if wandb_run is None:
|
| 194 |
+
return
|
| 195 |
+
try:
|
| 196 |
+
import wandb
|
| 197 |
+
wandb.log({
|
| 198 |
+
"phase_complete": True,
|
| 199 |
+
"phase": summary.get("phase", "unknown"),
|
| 200 |
+
"total_steps": summary.get("total_steps", 0),
|
| 201 |
+
"best_val_loss": summary.get("best_val_loss", 0),
|
| 202 |
+
"elapsed_minutes": summary.get("elapsed_minutes", 0),
|
| 203 |
+
})
|
| 204 |
+
except Exception:
|
| 205 |
+
pass
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def main() -> None:
|
| 209 |
+
args = parse_args()
|
| 210 |
+
|
| 211 |
+
print()
|
| 212 |
+
print("=" * 60)
|
| 213 |
+
print(" MINDI 1.5 Vision-Coder β Training Launch")
|
| 214 |
+
print(" MINDIGENOUS.AI")
|
| 215 |
+
print("=" * 60)
|
| 216 |
+
print()
|
| 217 |
+
print(f" Phase: {args.phase}")
|
| 218 |
+
print(f" Config: {args.config}")
|
| 219 |
+
print(f" Resume: {args.resume or 'None'}")
|
| 220 |
+
print(f" Dry run: {args.dry_run}")
|
| 221 |
+
print(f" Push to hub: {args.push_to_hub}")
|
| 222 |
+
print(f" Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
|
| 223 |
+
if torch.cuda.is_available():
|
| 224 |
+
print(f" GPU: {torch.cuda.get_device_name(0)}")
|
| 225 |
+
vram_gb = torch.cuda.get_device_properties(0).total_mem / (1024 ** 3)
|
| 226 |
+
print(f" VRAM: {vram_gb:.1f} GB")
|
| 227 |
+
print()
|
| 228 |
+
|
| 229 |
+
# Load config
|
| 230 |
+
raw_config = load_config(args.config)
|
| 231 |
+
config = build_training_config(raw_config, dry_run=args.dry_run)
|
| 232 |
+
|
| 233 |
+
# Filter phases based on --phase arg
|
| 234 |
+
if args.phase != "all":
|
| 235 |
+
phase_idx = int(args.phase) - 1
|
| 236 |
+
if phase_idx < 0 or phase_idx >= len(config.phases):
|
| 237 |
+
print(f"ERROR: Invalid phase {args.phase}. Available: 1-{len(config.phases)}")
|
| 238 |
+
sys.exit(1)
|
| 239 |
+
selected_phase = config.phases[phase_idx]
|
| 240 |
+
# Adjust to start from 0 for single-phase run
|
| 241 |
+
step_count = selected_phase.end_step - selected_phase.start_step
|
| 242 |
+
selected_phase.start_step = 0
|
| 243 |
+
selected_phase.end_step = step_count
|
| 244 |
+
config.phases = [selected_phase]
|
| 245 |
+
|
| 246 |
+
# Initialize model
|
| 247 |
+
print("[train.py] Initializing MINDI 1.5 model ...")
|
| 248 |
+
from src.model.mindi_model import MINDI15
|
| 249 |
+
model_cfg = raw_config.get("model", {})
|
| 250 |
+
vision_cfg = raw_config.get("vision", {})
|
| 251 |
+
|
| 252 |
+
model = MINDI15(
|
| 253 |
+
model_name=model_cfg.get("name", "Qwen/Qwen2.5-Coder-7B-Instruct"),
|
| 254 |
+
clip_model=vision_cfg.get("clip_model", "openai/clip-vit-large-patch14"),
|
| 255 |
+
hidden_size=model_cfg.get("hidden_size", 4096),
|
| 256 |
+
num_visual_tokens=vision_cfg.get("visual_tokens", 256),
|
| 257 |
+
torch_dtype=config.torch_dtype,
|
| 258 |
)
|
|
|
|
| 259 |
|
| 260 |
+
# Initialize trainer
|
| 261 |
+
from src.training.mindi_trainer import MINDITrainer
|
| 262 |
+
trainer = MINDITrainer(model=model, config=config)
|
| 263 |
+
|
| 264 |
+
# Resume from checkpoint
|
| 265 |
+
if args.resume:
|
| 266 |
+
resume_path = Path(args.resume)
|
| 267 |
+
if not resume_path.is_absolute():
|
| 268 |
+
resume_path = PROJECT_ROOT / resume_path
|
| 269 |
+
trainer.resume_from_checkpoint(resume_path)
|
| 270 |
+
|
| 271 |
+
# Initialize WandB
|
| 272 |
+
wandb_run = init_wandb(raw_config, args.phase, disabled=args.no_wandb)
|
| 273 |
+
|
| 274 |
+
# Graceful Ctrl+C handler
|
| 275 |
+
interrupted = False
|
| 276 |
+
|
| 277 |
+
def signal_handler(sig, frame):
|
| 278 |
+
nonlocal interrupted
|
| 279 |
+
if interrupted:
|
| 280 |
+
print("\n[train.py] Forced exit!")
|
| 281 |
+
sys.exit(1)
|
| 282 |
+
interrupted = True
|
| 283 |
+
print("\n[train.py] Ctrl+C received β saving checkpoint before exit ...")
|
| 284 |
+
try:
|
| 285 |
+
emergency_dir = config.output_dir / "emergency_checkpoint"
|
| 286 |
+
emergency_dir.mkdir(parents=True, exist_ok=True)
|
| 287 |
+
model.save(emergency_dir)
|
| 288 |
+
print(f"[train.py] Emergency checkpoint saved: {emergency_dir}")
|
| 289 |
+
except Exception as e:
|
| 290 |
+
print(f"[train.py] Emergency save failed: {e}")
|
| 291 |
+
sys.exit(0)
|
| 292 |
+
|
| 293 |
+
signal.signal(signal.SIGINT, signal_handler)
|
| 294 |
+
|
| 295 |
+
# Run training
|
| 296 |
+
try:
|
| 297 |
+
if args.phase == "all":
|
| 298 |
+
summary = trainer.train()
|
| 299 |
+
|
| 300 |
+
final_dir = config.output_dir / "final"
|
| 301 |
+
if args.push_to_hub:
|
| 302 |
+
push_checkpoint_to_hub(final_dir, raw_config)
|
| 303 |
+
log_wandb_phase_complete(wandb_run, summary)
|
| 304 |
+
|
| 305 |
+
else:
|
| 306 |
+
phase = config.phases[0]
|
| 307 |
+
summary = trainer.train_phase(phase)
|
| 308 |
+
|
| 309 |
+
ckpt_dir = config.output_dir / f"{phase.name}_step{phase.end_step}"
|
| 310 |
+
if args.push_to_hub:
|
| 311 |
+
push_checkpoint_to_hub(ckpt_dir, raw_config)
|
| 312 |
+
log_wandb_phase_complete(wandb_run, summary)
|
| 313 |
+
|
| 314 |
+
except KeyboardInterrupt:
|
| 315 |
+
signal_handler(None, None)
|
| 316 |
+
except Exception as e:
|
| 317 |
+
print(f"\n[train.py] ERROR: {e}")
|
| 318 |
+
traceback.print_exc()
|
| 319 |
+
try:
|
| 320 |
+
crash_dir = config.output_dir / "crash_checkpoint"
|
| 321 |
+
crash_dir.mkdir(parents=True, exist_ok=True)
|
| 322 |
+
model.save(crash_dir)
|
| 323 |
+
print(f"[train.py] Crash checkpoint saved: {crash_dir}")
|
| 324 |
+
except Exception:
|
| 325 |
+
pass
|
| 326 |
+
sys.exit(1)
|
| 327 |
+
finally:
|
| 328 |
+
if wandb_run is not None:
|
| 329 |
+
try:
|
| 330 |
+
import wandb
|
| 331 |
+
wandb.finish()
|
| 332 |
+
except Exception:
|
| 333 |
+
pass
|
| 334 |
|
| 335 |
+
# Final summary
|
| 336 |
+
hf_repo = raw_config.get("output", {}).get("hf_repo", "Mindigenous/MINDI-1.5-Vision-Coder")
|
| 337 |
+
print()
|
| 338 |
+
print("=" * 60)
|
| 339 |
+
print(" Training complete!")
|
| 340 |
+
print(f" Best val loss: {trainer.best_val_loss:.4f}")
|
| 341 |
+
print(f" Checkpoint at: {config.output_dir}")
|
| 342 |
+
if args.push_to_hub:
|
| 343 |
+
print(f" HuggingFace: https://huggingface.co/{hf_repo}")
|
| 344 |
+
print("=" * 60)
|
| 345 |
+
print()
|
| 346 |
|
| 347 |
|
| 348 |
if __name__ == "__main__":
|
scripts/upload_everything_to_hf.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Upload ENTIRE MINDI 1.5 Vision-Coder project to HuggingFace.
|
| 4 |
+
|
| 5 |
+
REPO 1 (model): Mindigenous/MINDI-1.5-Vision-Coder
|
| 6 |
+
REPO 2 (dataset): Mindigenous/MINDI-1.5-training-data
|
| 7 |
+
|
| 8 |
+
Both private. On MI300X we will clone these repos directly.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
import time
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
from dotenv import load_dotenv
|
| 17 |
+
from huggingface_hub import HfApi, create_repo
|
| 18 |
+
|
| 19 |
+
# ββ Paths ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 21 |
+
ENV_FILE = PROJECT_ROOT / ".env"
|
| 22 |
+
|
| 23 |
+
# ββ Repo names βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 24 |
+
MODEL_REPO = "Mindigenous/MINDI-1.5-Vision-Coder"
|
| 25 |
+
DATASET_REPO = "Mindigenous/MINDI-1.5-training-data"
|
| 26 |
+
|
| 27 |
+
# ββ Model card (written to repo as README.md) βββββββββββββββββββββββββ
|
| 28 |
+
MODEL_CARD = """\
|
| 29 |
+
---
|
| 30 |
+
license: apache-2.0
|
| 31 |
+
language:
|
| 32 |
+
- en
|
| 33 |
+
tags:
|
| 34 |
+
- code-generation
|
| 35 |
+
- nextjs
|
| 36 |
+
- react
|
| 37 |
+
- typescript
|
| 38 |
+
- vision
|
| 39 |
+
- multimodal
|
| 40 |
+
- mindi
|
| 41 |
+
- mindigenous
|
| 42 |
+
base_model: Qwen/Qwen2.5-Coder-7B-Instruct
|
| 43 |
+
---
|
| 44 |
+
|
| 45 |
+
# MINDI 1.5 Vision-Coder
|
| 46 |
+
|
| 47 |
+
Built by MINDIGENOUS.AI
|
| 48 |
+
|
| 49 |
+
## Model Description
|
| 50 |
+
MINDI 1.5 is an agentic AI coding model
|
| 51 |
+
that sees its own output and critiques it.
|
| 52 |
+
|
| 53 |
+
## Key Features
|
| 54 |
+
- Generates Next.js 14 + Tailwind + TypeScript
|
| 55 |
+
- Sees screenshots via CLIP ViT-L/14
|
| 56 |
+
- Critiques its own UI/UX output
|
| 57 |
+
- Searches internet for latest packages
|
| 58 |
+
- Tests code in sandbox environment
|
| 59 |
+
- Self-fixes errors automatically
|
| 60 |
+
|
| 61 |
+
## Training
|
| 62 |
+
- Base: Qwen/Qwen2.5-Coder-7B-Instruct
|
| 63 |
+
- Method: LoRA fine-tuning
|
| 64 |
+
- Hardware: AMD MI300X 192GB VRAM
|
| 65 |
+
- Dataset: 1,449,428 examples
|
| 66 |
+
- Tokens: 859,694,776
|
| 67 |
+
- Status: Training in progress
|
| 68 |
+
|
| 69 |
+
## Built By
|
| 70 |
+
Faaz - MINDIGENOUS.AI
|
| 71 |
+
Mumbai, India
|
| 72 |
+
April 2026
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
# ββ Dataset card βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 76 |
+
DATASET_CARD = """\
|
| 77 |
+
---
|
| 78 |
+
license: apache-2.0
|
| 79 |
+
language:
|
| 80 |
+
- en
|
| 81 |
+
tags:
|
| 82 |
+
- code-generation
|
| 83 |
+
- nextjs
|
| 84 |
+
- react
|
| 85 |
+
- typescript
|
| 86 |
+
- vision
|
| 87 |
+
- multimodal
|
| 88 |
+
- mindi
|
| 89 |
+
- mindigenous
|
| 90 |
+
size_categories:
|
| 91 |
+
- 1M<n<10M
|
| 92 |
+
---
|
| 93 |
+
|
| 94 |
+
# MINDI 1.5 Training Data
|
| 95 |
+
|
| 96 |
+
Training dataset for **MINDI 1.5 Vision-Coder** by MINDIGENOUS.AI
|
| 97 |
+
|
| 98 |
+
## Dataset Statistics
|
| 99 |
+
| Metric | Value |
|
| 100 |
+
|--------|-------|
|
| 101 |
+
| Total examples | 1,449,428 |
|
| 102 |
+
| Total tokens | 859,694,776 |
|
| 103 |
+
| Avg tokens/example | 593 |
|
| 104 |
+
| Avg quality score | 6.49 |
|
| 105 |
+
| Sources | 9 |
|
| 106 |
+
|
| 107 |
+
## Splits
|
| 108 |
+
| Split | Examples | Percentage |
|
| 109 |
+
|-------|----------|------------|
|
| 110 |
+
| Train | 1,304,486 | 90.0% |
|
| 111 |
+
| Validation | 72,471 | 5.0% |
|
| 112 |
+
| Test | 72,471 | 5.0% |
|
| 113 |
+
|
| 114 |
+
## Sources
|
| 115 |
+
| Source | Examples | Kept % |
|
| 116 |
+
|--------|----------|--------|
|
| 117 |
+
| starcoderdata | 569,350 | 94.9% |
|
| 118 |
+
| websight | 250,987 | 99.99% |
|
| 119 |
+
| evol_code | 155,998 | 99.7% |
|
| 120 |
+
| codefeedback | 149,865 | 99.9% |
|
| 121 |
+
| magicoder | 149,987 | 99.99% |
|
| 122 |
+
| synthetic_nextjs | 90,000 | 100% (protected) |
|
| 123 |
+
| codealpaca | 59,241 | 98.8% |
|
| 124 |
+
| search_examples | 15,000 | 100% (protected) |
|
| 125 |
+
| sandbox_examples | 9,000 | 100% (protected) |
|
| 126 |
+
|
| 127 |
+
## Type Distribution
|
| 128 |
+
| Type | Examples |
|
| 129 |
+
|------|----------|
|
| 130 |
+
| code_generation | 1,183,441 |
|
| 131 |
+
| vision_code | 250,987 |
|
| 132 |
+
| search | 15,000 |
|
| 133 |
+
|
| 134 |
+
## Language Distribution
|
| 135 |
+
| Language | Examples |
|
| 136 |
+
|----------|----------|
|
| 137 |
+
| unknown | 490,305 |
|
| 138 |
+
| typescript | 375,859 |
|
| 139 |
+
| javascript | 298,497 |
|
| 140 |
+
| python | 211,842 |
|
| 141 |
+
| html | 36,371 |
|
| 142 |
+
| java | 32,458 |
|
| 143 |
+
| rust | 3,709 |
|
| 144 |
+
| go | 387 |
|
| 145 |
+
|
| 146 |
+
## Format
|
| 147 |
+
Each example is a JSON object with:
|
| 148 |
+
- `conversations`: list of `{"role": ..., "content": ...}` turns
|
| 149 |
+
- `source`: dataset origin
|
| 150 |
+
- `type`: code_generation / vision_code / search
|
| 151 |
+
- `language`: programming language
|
| 152 |
+
- `quality_score`: heuristic quality (0-10+)
|
| 153 |
+
- `token_count`: number of tokens
|
| 154 |
+
|
| 155 |
+
## Quality Filtering
|
| 156 |
+
- Protected sources (sandbox, search, synthetic_nextjs) bypass aggressive filters
|
| 157 |
+
- MINDI special token bonuses boost agentic examples
|
| 158 |
+
- Dedup via SHA-256 content hashing
|
| 159 |
+
- Rejection reasons: too_many_tokens (30,637), boilerplate (1,373), duplicate (59)
|
| 160 |
+
|
| 161 |
+
## Built By
|
| 162 |
+
Faaz - MINDIGENOUS.AI
|
| 163 |
+
Mumbai, India β April 2026
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 168 |
+
def load_token() -> str:
|
| 169 |
+
"""Load HF token from .env."""
|
| 170 |
+
load_dotenv(ENV_FILE)
|
| 171 |
+
token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
|
| 172 |
+
if not token:
|
| 173 |
+
print("ERROR: No HUGGINGFACE_TOKEN or HF_TOKEN found in .env")
|
| 174 |
+
sys.exit(1)
|
| 175 |
+
return token
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def ensure_repo(api: HfApi, repo_id: str, repo_type: str, token: str):
|
| 179 |
+
"""Create repo if it doesn't exist."""
|
| 180 |
+
try:
|
| 181 |
+
create_repo(
|
| 182 |
+
repo_id=repo_id,
|
| 183 |
+
repo_type=repo_type,
|
| 184 |
+
private=True,
|
| 185 |
+
token=token,
|
| 186 |
+
exist_ok=True,
|
| 187 |
+
)
|
| 188 |
+
print(f" Repo ready: {repo_id} ({repo_type})")
|
| 189 |
+
except Exception as e:
|
| 190 |
+
print(f" Repo create/check: {e}")
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def upload_folder(api: HfApi, local: Path, remote: str, repo_id: str,
|
| 194 |
+
repo_type: str, token: str):
|
| 195 |
+
"""Upload a local folder to HF repo."""
|
| 196 |
+
if not local.exists():
|
| 197 |
+
print(f" SKIP (not found): {local}")
|
| 198 |
+
return
|
| 199 |
+
label = str(local.relative_to(PROJECT_ROOT))
|
| 200 |
+
print(f" Uploading {label}/ to {repo_type} repo ... ", end="", flush=True)
|
| 201 |
+
t0 = time.time()
|
| 202 |
+
api.upload_folder(
|
| 203 |
+
repo_id=repo_id,
|
| 204 |
+
repo_type=repo_type,
|
| 205 |
+
folder_path=str(local),
|
| 206 |
+
path_in_repo=remote,
|
| 207 |
+
token=token,
|
| 208 |
+
ignore_patterns=["__pycache__", "*.pyc", ".git"],
|
| 209 |
+
)
|
| 210 |
+
print(f"done ({time.time() - t0:.1f}s)")
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def upload_file(api: HfApi, local: Path, remote: str, repo_id: str,
|
| 214 |
+
repo_type: str, token: str):
|
| 215 |
+
"""Upload a single file to HF repo."""
|
| 216 |
+
if not local.exists():
|
| 217 |
+
print(f" SKIP (not found): {local.name}")
|
| 218 |
+
return
|
| 219 |
+
size_mb = local.stat().st_size / (1024 * 1024)
|
| 220 |
+
label = str(local.relative_to(PROJECT_ROOT))
|
| 221 |
+
print(f" Uploading {label} ({size_mb:.1f} MB) to {repo_type} repo ... ",
|
| 222 |
+
end="", flush=True)
|
| 223 |
+
t0 = time.time()
|
| 224 |
+
api.upload_file(
|
| 225 |
+
repo_id=repo_id,
|
| 226 |
+
repo_type=repo_type,
|
| 227 |
+
path_or_fileobj=str(local),
|
| 228 |
+
path_in_repo=remote,
|
| 229 |
+
token=token,
|
| 230 |
+
)
|
| 231 |
+
print(f"done ({time.time() - t0:.1f}s)")
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def upload_readme(api: HfApi, content: str, repo_id: str,
|
| 235 |
+
repo_type: str, token: str):
|
| 236 |
+
"""Upload a README.md string to a repo."""
|
| 237 |
+
print(f" Uploading README.md to {repo_type} repo ... ", end="", flush=True)
|
| 238 |
+
api.upload_file(
|
| 239 |
+
repo_id=repo_id,
|
| 240 |
+
repo_type=repo_type,
|
| 241 |
+
path_or_fileobj=content.encode("utf-8"),
|
| 242 |
+
path_in_repo="README.md",
|
| 243 |
+
token=token,
|
| 244 |
+
)
|
| 245 |
+
print("done")
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 249 |
+
def main():
|
| 250 |
+
print("=" * 60)
|
| 251 |
+
print(" MINDI 1.5 β Upload Everything to HuggingFace")
|
| 252 |
+
print("=" * 60)
|
| 253 |
+
print()
|
| 254 |
+
|
| 255 |
+
token = load_token()
|
| 256 |
+
api = HfApi()
|
| 257 |
+
|
| 258 |
+
# ββ Create repos βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 259 |
+
print("[1/4] Creating repositories ...")
|
| 260 |
+
ensure_repo(api, MODEL_REPO, "model", token)
|
| 261 |
+
ensure_repo(api, DATASET_REPO, "dataset", token)
|
| 262 |
+
print()
|
| 263 |
+
|
| 264 |
+
# ββ REPO 1: Model (code + configs) βββββββββββββββββββββββββββββ
|
| 265 |
+
print("[2/4] Uploading to MODEL repo:", MODEL_REPO)
|
| 266 |
+
print("-" * 50)
|
| 267 |
+
|
| 268 |
+
# Folders
|
| 269 |
+
model_folders = [
|
| 270 |
+
(PROJECT_ROOT / "src", "src"),
|
| 271 |
+
(PROJECT_ROOT / "scripts", "scripts"),
|
| 272 |
+
(PROJECT_ROOT / "configs", "configs"),
|
| 273 |
+
(PROJECT_ROOT / "data" / "tokenizer", "data/tokenizer"),
|
| 274 |
+
(PROJECT_ROOT / "tests", "tests"),
|
| 275 |
+
(PROJECT_ROOT / "api", "api"),
|
| 276 |
+
]
|
| 277 |
+
for local, remote in model_folders:
|
| 278 |
+
upload_folder(api, local, remote, MODEL_REPO, "model", token)
|
| 279 |
+
|
| 280 |
+
# Single files
|
| 281 |
+
model_files = [
|
| 282 |
+
(PROJECT_ROOT / "requirements.txt", "requirements.txt"),
|
| 283 |
+
(PROJECT_ROOT / "setup.py", "setup.py"),
|
| 284 |
+
(PROJECT_ROOT / "activate_mindi.bat", "activate_mindi.bat"),
|
| 285 |
+
(PROJECT_ROOT / ".env.example", ".env.example"),
|
| 286 |
+
]
|
| 287 |
+
for local, remote in model_files:
|
| 288 |
+
upload_file(api, local, remote, MODEL_REPO, "model", token)
|
| 289 |
+
|
| 290 |
+
# setup_mi300x.sh
|
| 291 |
+
mi300x_sh = PROJECT_ROOT / "setup_mi300x.sh"
|
| 292 |
+
if mi300x_sh.exists():
|
| 293 |
+
upload_file(api, mi300x_sh, "setup_mi300x.sh", MODEL_REPO, "model", token)
|
| 294 |
+
|
| 295 |
+
# Model card replaces README.md
|
| 296 |
+
upload_readme(api, MODEL_CARD, MODEL_REPO, "model", token)
|
| 297 |
+
print()
|
| 298 |
+
|
| 299 |
+
# ββ REPO 2: Dataset ββββββββββββββββββββββββββββββββββββββββββββ
|
| 300 |
+
print("[3/4] Uploading to DATASET repo:", DATASET_REPO)
|
| 301 |
+
print("-" * 50)
|
| 302 |
+
|
| 303 |
+
processed = PROJECT_ROOT / "data" / "processed"
|
| 304 |
+
dataset_files = [
|
| 305 |
+
(processed / "train.jsonl", "processed/train.jsonl"),
|
| 306 |
+
(processed / "val.jsonl", "processed/val.jsonl"),
|
| 307 |
+
(processed / "test.jsonl", "processed/test.jsonl"),
|
| 308 |
+
(processed / "mindi_filtered.jsonl", "processed/mindi_filtered.jsonl"),
|
| 309 |
+
(processed / "filter_report.json", "processed/filter_report.json"),
|
| 310 |
+
(processed / "split_meta.json", "processed/split_meta.json"),
|
| 311 |
+
]
|
| 312 |
+
for local, remote in dataset_files:
|
| 313 |
+
upload_file(api, local, remote, DATASET_REPO, "dataset", token)
|
| 314 |
+
|
| 315 |
+
# Raw data folder
|
| 316 |
+
upload_folder(
|
| 317 |
+
api, PROJECT_ROOT / "data" / "raw", "raw",
|
| 318 |
+
DATASET_REPO, "dataset", token,
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
# Tokenizer copy in dataset repo
|
| 322 |
+
upload_folder(
|
| 323 |
+
api, PROJECT_ROOT / "data" / "tokenizer", "tokenizer",
|
| 324 |
+
DATASET_REPO, "dataset", token,
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
# Dataset card
|
| 328 |
+
upload_readme(api, DATASET_CARD, DATASET_REPO, "dataset", token)
|
| 329 |
+
print()
|
| 330 |
+
|
| 331 |
+
# ββ Done βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 332 |
+
print("[4/4] Upload complete!")
|
| 333 |
+
print()
|
| 334 |
+
print("ββββββββββββββββββββββββββββββββββββββββ")
|
| 335 |
+
print("β UPLOAD COMPLETE! β")
|
| 336 |
+
print("β β")
|
| 337 |
+
print("β Model repo: β")
|
| 338 |
+
print("β huggingface.co/Mindigenous/ β")
|
| 339 |
+
print("β MINDI-1.5-Vision-Coder β")
|
| 340 |
+
print("β β")
|
| 341 |
+
print("β Dataset repo: β")
|
| 342 |
+
print("β huggingface.co/datasets/ β")
|
| 343 |
+
print("β Mindigenous/MINDI-1.5-training-data β")
|
| 344 |
+
print("β β")
|
| 345 |
+
print("β On MI300X just run: β")
|
| 346 |
+
print("β git clone https://huggingface.co/ β")
|
| 347 |
+
print("β Mindigenous/MINDI-1.5-Vision-Coder β")
|
| 348 |
+
print("β β")
|
| 349 |
+
print("β Ready to train! π β")
|
| 350 |
+
print("ββββββββββββββββββββββββββββββββββββββββ")
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
if __name__ == "__main__":
|
| 354 |
+
main()
|
setup_mi300x.sh
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# ============================================================
|
| 3 |
+
# MINDI 1.5 Vision-Coder β MI300X Setup Script
|
| 4 |
+
# One command to set up everything on DigitalOcean AMD MI300X
|
| 5 |
+
# ============================================================
|
| 6 |
+
set -e
|
| 7 |
+
|
| 8 |
+
echo "============================================================"
|
| 9 |
+
echo " MINDI 1.5 Vision-Coder β MI300X Setup"
|
| 10 |
+
echo " MINDIGENOUS.AI"
|
| 11 |
+
echo "============================================================"
|
| 12 |
+
echo ""
|
| 13 |
+
|
| 14 |
+
# ββ Check HF_TOKEN βββββββββββββββββββββββββββββββββββββββββββββ
|
| 15 |
+
if [ -z "$HF_TOKEN" ]; then
|
| 16 |
+
echo "ERROR: Set HF_TOKEN environment variable first!"
|
| 17 |
+
echo " export HF_TOKEN=hf_your_token_here"
|
| 18 |
+
exit 1
|
| 19 |
+
fi
|
| 20 |
+
|
| 21 |
+
# ββ Step 1: Install ROCm PyTorch βββββββββββββββββββββββββββββββ
|
| 22 |
+
echo "[1/7] Installing ROCm PyTorch (ROCm 6.0) ..."
|
| 23 |
+
pip install torch torchvision torchaudio \
|
| 24 |
+
--index-url https://download.pytorch.org/whl/rocm6.0
|
| 25 |
+
|
| 26 |
+
# ββ Step 2: Clone the full project from HF ββββββββββββββββββββ
|
| 27 |
+
echo ""
|
| 28 |
+
echo "[2/7] Cloning MINDI 1.5 from HuggingFace ..."
|
| 29 |
+
if [ -d "MINDI-1.5-Vision-Coder" ]; then
|
| 30 |
+
echo " Directory exists β pulling latest ..."
|
| 31 |
+
cd MINDI-1.5-Vision-Coder
|
| 32 |
+
git pull
|
| 33 |
+
else
|
| 34 |
+
git clone https://${HF_TOKEN}@huggingface.co/Mindigenous/MINDI-1.5-Vision-Coder
|
| 35 |
+
cd MINDI-1.5-Vision-Coder
|
| 36 |
+
fi
|
| 37 |
+
|
| 38 |
+
# ββ Step 3: Install Python requirements ββββββββββββββββββββββββ
|
| 39 |
+
echo ""
|
| 40 |
+
echo "[3/7] Installing Python requirements ..."
|
| 41 |
+
pip install -r requirements.txt
|
| 42 |
+
|
| 43 |
+
# Additional training dependencies
|
| 44 |
+
pip install wandb huggingface_hub accelerate
|
| 45 |
+
|
| 46 |
+
# ββ Step 4: Download training data from HF βββββββββββββββββββββ
|
| 47 |
+
echo ""
|
| 48 |
+
echo "[4/7] Downloading training dataset ..."
|
| 49 |
+
python -c "
|
| 50 |
+
from huggingface_hub import snapshot_download
|
| 51 |
+
import os
|
| 52 |
+
|
| 53 |
+
snapshot_download(
|
| 54 |
+
repo_id='Mindigenous/MINDI-1.5-training-data',
|
| 55 |
+
repo_type='dataset',
|
| 56 |
+
local_dir='data/',
|
| 57 |
+
token=os.environ['HF_TOKEN']
|
| 58 |
+
)
|
| 59 |
+
print('Dataset downloaded!')
|
| 60 |
+
"
|
| 61 |
+
|
| 62 |
+
# Verify data files exist
|
| 63 |
+
echo " Checking data files ..."
|
| 64 |
+
if [ ! -f "data/processed/train.jsonl" ]; then
|
| 65 |
+
echo " ERROR: train.jsonl not found!"
|
| 66 |
+
exit 1
|
| 67 |
+
fi
|
| 68 |
+
if [ ! -f "data/processed/val.jsonl" ]; then
|
| 69 |
+
echo " ERROR: val.jsonl not found!"
|
| 70 |
+
exit 1
|
| 71 |
+
fi
|
| 72 |
+
TRAIN_SIZE=$(du -sh data/processed/train.jsonl | cut -f1)
|
| 73 |
+
VAL_SIZE=$(du -sh data/processed/val.jsonl | cut -f1)
|
| 74 |
+
echo " train.jsonl: ${TRAIN_SIZE}"
|
| 75 |
+
echo " val.jsonl: ${VAL_SIZE}"
|
| 76 |
+
|
| 77 |
+
# ββ Step 5: Set environment variables ββββββββββββββββββββββββββ
|
| 78 |
+
echo ""
|
| 79 |
+
echo "[5/7] Setting environment variables ..."
|
| 80 |
+
|
| 81 |
+
# ROCm / PyTorch settings
|
| 82 |
+
export HSA_OVERRIDE_GFX_VERSION=11.0.0
|
| 83 |
+
export PYTORCH_ROCM_ARCH="gfx942"
|
| 84 |
+
export HIP_VISIBLE_DEVICES=0
|
| 85 |
+
export TOKENIZERS_PARALLELISM=false
|
| 86 |
+
export WANDB_PROJECT="mindi-1.5-vision-coder"
|
| 87 |
+
|
| 88 |
+
# Create .env file
|
| 89 |
+
cat > .env << EOF
|
| 90 |
+
HF_TOKEN=${HF_TOKEN}
|
| 91 |
+
HSA_OVERRIDE_GFX_VERSION=11.0.0
|
| 92 |
+
PYTORCH_ROCM_ARCH=gfx942
|
| 93 |
+
HIP_VISIBLE_DEVICES=0
|
| 94 |
+
TOKENIZERS_PARALLELISM=false
|
| 95 |
+
WANDB_PROJECT=mindi-1.5-vision-coder
|
| 96 |
+
EOF
|
| 97 |
+
echo " .env file created"
|
| 98 |
+
|
| 99 |
+
# ββ Step 6: Verify GPU detected βββββββββββββββββββββββββββββββ
|
| 100 |
+
echo ""
|
| 101 |
+
echo "[6/7] Verifying GPU ..."
|
| 102 |
+
python -c "
|
| 103 |
+
import torch
|
| 104 |
+
print(f' PyTorch version: {torch.__version__}')
|
| 105 |
+
print(f' CUDA available: {torch.cuda.is_available()}')
|
| 106 |
+
if torch.cuda.is_available():
|
| 107 |
+
print(f' GPU name: {torch.cuda.get_device_name(0)}')
|
| 108 |
+
vram = torch.cuda.get_device_properties(0).total_mem / (1024**3)
|
| 109 |
+
print(f' VRAM: {vram:.1f} GB')
|
| 110 |
+
print(f' ROCm backend: {torch.version.hip is not None}')
|
| 111 |
+
else:
|
| 112 |
+
print(' WARNING: No GPU detected!')
|
| 113 |
+
exit(1)
|
| 114 |
+
"
|
| 115 |
+
|
| 116 |
+
# Quick bf16 test
|
| 117 |
+
python -c "
|
| 118 |
+
import torch
|
| 119 |
+
x = torch.randn(100, 100, dtype=torch.bfloat16, device='cuda')
|
| 120 |
+
y = torch.matmul(x, x.T)
|
| 121 |
+
print(f' bf16 matmul test: PASSED (shape={y.shape})')
|
| 122 |
+
"
|
| 123 |
+
|
| 124 |
+
# ββ Step 7: Create output directories βββββββββββββββββββββββββ
|
| 125 |
+
echo ""
|
| 126 |
+
echo "[7/7] Creating output directories ..."
|
| 127 |
+
mkdir -p checkpoints/training
|
| 128 |
+
mkdir -p checkpoints/best
|
| 129 |
+
mkdir -p logs/training
|
| 130 |
+
|
| 131 |
+
# ββ Done βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 132 |
+
echo ""
|
| 133 |
+
echo "============================================================"
|
| 134 |
+
echo " MINDI 1.5 Vision-Coder β MI300X Ready!"
|
| 135 |
+
echo ""
|
| 136 |
+
echo " Project: $(pwd)"
|
| 137 |
+
echo " Data: ${TRAIN_SIZE} train / ${VAL_SIZE} val"
|
| 138 |
+
echo " GPU: $(python -c 'import torch; print(torch.cuda.get_device_name(0))' 2>/dev/null || echo 'N/A')"
|
| 139 |
+
echo ""
|
| 140 |
+
echo " Ready to train!"
|
| 141 |
+
echo " Run: python scripts/train.py --phase 1"
|
| 142 |
+
echo ""
|
| 143 |
+
echo " Or dry run first:"
|
| 144 |
+
echo " Run: python scripts/train.py --dry_run --no_wandb"
|
| 145 |
+
echo "============================================================"
|
src/model/architecture.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MINDI 1.5 Vision-Coder β Model Architecture
|
| 3 |
+
|
| 4 |
+
Loads Qwen/Qwen2.5-Coder-7B-Instruct with LoRA adapters.
|
| 5 |
+
Handles model initialization, LoRA application, save/load,
|
| 6 |
+
and parameter counting for the base LLM component.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Optional
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
|
| 16 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class MINDIArchitecture:
|
| 20 |
+
"""Qwen2.5-Coder-7B-Instruct with LoRA for MINDI 1.5 fine-tuning."""
|
| 21 |
+
|
| 22 |
+
DEFAULT_TARGET_MODULES: list[str] = [
|
| 23 |
+
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 24 |
+
"gate_proj", "up_proj", "down_proj",
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
model_name: str = "Qwen/Qwen2.5-Coder-7B-Instruct",
|
| 30 |
+
device: Optional[str] = None,
|
| 31 |
+
cache_dir: Optional[Path] = None,
|
| 32 |
+
torch_dtype: torch.dtype = torch.bfloat16,
|
| 33 |
+
) -> None:
|
| 34 |
+
"""
|
| 35 |
+
Initialize the architecture wrapper.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
model_name: HuggingFace model identifier.
|
| 39 |
+
device: Target device ('cuda', 'cpu', or None for auto).
|
| 40 |
+
cache_dir: Local directory for model weight cache.
|
| 41 |
+
torch_dtype: Data type for model weights.
|
| 42 |
+
"""
|
| 43 |
+
self.model_name = model_name
|
| 44 |
+
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 45 |
+
self.cache_dir = Path(cache_dir) if cache_dir else Path("./checkpoints/base")
|
| 46 |
+
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
| 47 |
+
self.torch_dtype = torch_dtype
|
| 48 |
+
|
| 49 |
+
self.model: Optional[AutoModelForCausalLM] = None
|
| 50 |
+
self.peft_model: Optional[PeftModel] = None
|
| 51 |
+
self.tokenizer: Optional[AutoTokenizer] = None
|
| 52 |
+
|
| 53 |
+
self._load_model()
|
| 54 |
+
|
| 55 |
+
def _load_model(self) -> None:
|
| 56 |
+
"""Load the base model and tokenizer from HuggingFace or cache."""
|
| 57 |
+
print(f"[MINDIArchitecture] Loading {self.model_name} ...")
|
| 58 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 59 |
+
self.model_name,
|
| 60 |
+
cache_dir=str(self.cache_dir),
|
| 61 |
+
torch_dtype=self.torch_dtype,
|
| 62 |
+
device_map="auto" if self.device == "cuda" else None,
|
| 63 |
+
trust_remote_code=True,
|
| 64 |
+
)
|
| 65 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 66 |
+
self.model_name,
|
| 67 |
+
cache_dir=str(self.cache_dir),
|
| 68 |
+
trust_remote_code=True,
|
| 69 |
+
)
|
| 70 |
+
print(f"[MINDIArchitecture] Loaded on {self.device} "
|
| 71 |
+
f"({self._fmt_params(self._total_params())} params)")
|
| 72 |
+
|
| 73 |
+
def apply_lora(
|
| 74 |
+
self,
|
| 75 |
+
r: int = 64,
|
| 76 |
+
lora_alpha: int = 128,
|
| 77 |
+
lora_dropout: float = 0.05,
|
| 78 |
+
target_modules: Optional[list[str]] = None,
|
| 79 |
+
) -> PeftModel:
|
| 80 |
+
"""
|
| 81 |
+
Apply LoRA adapters to the base model.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
r: LoRA rank.
|
| 85 |
+
lora_alpha: LoRA scaling factor.
|
| 86 |
+
lora_dropout: Dropout probability for LoRA layers.
|
| 87 |
+
target_modules: List of module names to apply LoRA to.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
The PEFT-wrapped model.
|
| 91 |
+
"""
|
| 92 |
+
if self.model is None:
|
| 93 |
+
raise RuntimeError("Base model not loaded.")
|
| 94 |
+
|
| 95 |
+
if target_modules is None:
|
| 96 |
+
target_modules = self.DEFAULT_TARGET_MODULES
|
| 97 |
+
|
| 98 |
+
lora_config = LoraConfig(
|
| 99 |
+
r=r,
|
| 100 |
+
lora_alpha=lora_alpha,
|
| 101 |
+
lora_dropout=lora_dropout,
|
| 102 |
+
target_modules=target_modules,
|
| 103 |
+
bias="none",
|
| 104 |
+
task_type=TaskType.CAUSAL_LM,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
self.peft_model = get_peft_model(self.model, lora_config)
|
| 108 |
+
|
| 109 |
+
info = self.get_trainable_params()
|
| 110 |
+
print(f"[MINDIArchitecture] LoRA applied (r={r}, alpha={lora_alpha})")
|
| 111 |
+
print(f" Trainable: {info['trainable']:>14,} ({info['trainable_pct']:.2f}%)")
|
| 112 |
+
print(f" Frozen: {info['frozen']:>14,}")
|
| 113 |
+
print(f" Total: {info['total']:>14,}")
|
| 114 |
+
|
| 115 |
+
return self.peft_model
|
| 116 |
+
|
| 117 |
+
def get_trainable_params(self) -> dict:
|
| 118 |
+
"""
|
| 119 |
+
Count trainable, frozen, and total parameters.
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
Dictionary with 'trainable', 'frozen', 'total', 'trainable_pct'.
|
| 123 |
+
"""
|
| 124 |
+
model = self.peft_model or self.model
|
| 125 |
+
if model is None:
|
| 126 |
+
return {"trainable": 0, "frozen": 0, "total": 0, "trainable_pct": 0.0}
|
| 127 |
+
|
| 128 |
+
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 129 |
+
total = sum(p.numel() for p in model.parameters())
|
| 130 |
+
frozen = total - trainable
|
| 131 |
+
pct = 100.0 * trainable / total if total > 0 else 0.0
|
| 132 |
+
|
| 133 |
+
return {
|
| 134 |
+
"trainable": trainable,
|
| 135 |
+
"frozen": frozen,
|
| 136 |
+
"total": total,
|
| 137 |
+
"trainable_pct": round(pct, 4),
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
def print_model_info(self) -> None:
|
| 141 |
+
"""Print detailed model architecture and parameter information."""
|
| 142 |
+
model = self.peft_model or self.model
|
| 143 |
+
if model is None:
|
| 144 |
+
print("[MINDIArchitecture] No model loaded.")
|
| 145 |
+
return
|
| 146 |
+
|
| 147 |
+
info = self.get_trainable_params()
|
| 148 |
+
print()
|
| 149 |
+
print("=" * 60)
|
| 150 |
+
print(" MINDI 1.5 β Model Architecture Info")
|
| 151 |
+
print("=" * 60)
|
| 152 |
+
print(f" Base model: {self.model_name}")
|
| 153 |
+
print(f" Device: {self.device}")
|
| 154 |
+
print(f" Dtype: {self.torch_dtype}")
|
| 155 |
+
print(f" LoRA active: {self.peft_model is not None}")
|
| 156 |
+
print(f" Total params: {self._fmt_params(info['total'])}")
|
| 157 |
+
print(f" Trainable: {self._fmt_params(info['trainable'])} "
|
| 158 |
+
f"({info['trainable_pct']:.2f}%)")
|
| 159 |
+
print(f" Frozen: {self._fmt_params(info['frozen'])}")
|
| 160 |
+
|
| 161 |
+
if self.peft_model is not None:
|
| 162 |
+
config = self.peft_model.peft_config.get("default")
|
| 163 |
+
if config is not None:
|
| 164 |
+
print(f" LoRA rank: {config.r}")
|
| 165 |
+
print(f" LoRA alpha: {config.lora_alpha}")
|
| 166 |
+
print(f" LoRA dropout: {config.lora_dropout}")
|
| 167 |
+
print(f" Target modules: {config.target_modules}")
|
| 168 |
+
print("=" * 60)
|
| 169 |
+
print()
|
| 170 |
+
|
| 171 |
+
def save_lora(self, path: Optional[Path] = None) -> Path:
|
| 172 |
+
"""
|
| 173 |
+
Save LoRA adapter weights to disk.
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
path: Directory to save to. Defaults to checkpoints/lora.
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
Path where weights were saved.
|
| 180 |
+
"""
|
| 181 |
+
if self.peft_model is None:
|
| 182 |
+
raise RuntimeError("No LoRA adapter to save. Call apply_lora() first.")
|
| 183 |
+
|
| 184 |
+
save_path = Path(path) if path else Path("./checkpoints/lora")
|
| 185 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
| 186 |
+
self.peft_model.save_pretrained(str(save_path))
|
| 187 |
+
print(f"[MINDIArchitecture] LoRA saved to {save_path}")
|
| 188 |
+
return save_path
|
| 189 |
+
|
| 190 |
+
def load_lora(self, path: Path) -> PeftModel:
|
| 191 |
+
"""
|
| 192 |
+
Load LoRA adapter weights from disk.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
path: Directory containing saved adapter weights.
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
The PEFT-wrapped model with loaded adapter.
|
| 199 |
+
"""
|
| 200 |
+
path = Path(path)
|
| 201 |
+
if not path.exists():
|
| 202 |
+
raise FileNotFoundError(f"LoRA adapter not found: {path}")
|
| 203 |
+
if self.model is None:
|
| 204 |
+
raise RuntimeError("Base model not loaded.")
|
| 205 |
+
|
| 206 |
+
self.peft_model = PeftModel.from_pretrained(
|
| 207 |
+
self.model, str(path)
|
| 208 |
+
)
|
| 209 |
+
print(f"[MINDIArchitecture] LoRA loaded from {path}")
|
| 210 |
+
return self.peft_model
|
| 211 |
+
|
| 212 |
+
def resize_embeddings(self, new_vocab_size: int) -> None:
|
| 213 |
+
"""Resize model embeddings for new special tokens."""
|
| 214 |
+
model = self.peft_model or self.model
|
| 215 |
+
if model is None:
|
| 216 |
+
raise RuntimeError("No model loaded.")
|
| 217 |
+
old_size = model.get_input_embeddings().weight.shape[0]
|
| 218 |
+
if new_vocab_size != old_size:
|
| 219 |
+
model.resize_token_embeddings(new_vocab_size)
|
| 220 |
+
print(f"[MINDIArchitecture] Resized embeddings: {old_size} β {new_vocab_size}")
|
| 221 |
+
|
| 222 |
+
def get_model(self) -> AutoModelForCausalLM | PeftModel:
|
| 223 |
+
"""Return the active model (PEFT if LoRA applied, else base)."""
|
| 224 |
+
model = self.peft_model or self.model
|
| 225 |
+
if model is None:
|
| 226 |
+
raise RuntimeError("No model loaded.")
|
| 227 |
+
return model
|
| 228 |
+
|
| 229 |
+
# ββ helpers βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 230 |
+
def _total_params(self) -> int:
|
| 231 |
+
model = self.peft_model or self.model
|
| 232 |
+
if model is None:
|
| 233 |
+
return 0
|
| 234 |
+
return sum(p.numel() for p in model.parameters())
|
| 235 |
+
|
| 236 |
+
@staticmethod
|
| 237 |
+
def _fmt_params(n: int) -> str:
|
| 238 |
+
if n >= 1_000_000_000:
|
| 239 |
+
return f"{n / 1_000_000_000:.2f}B"
|
| 240 |
+
if n >= 1_000_000:
|
| 241 |
+
return f"{n / 1_000_000:.2f}M"
|
| 242 |
+
if n >= 1_000:
|
| 243 |
+
return f"{n / 1_000:.1f}K"
|
| 244 |
+
return str(n)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
# ββ Test block ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 248 |
+
if __name__ == "__main__":
|
| 249 |
+
print("=" * 60)
|
| 250 |
+
print(" MINDI 1.5 β Architecture Test")
|
| 251 |
+
print("=" * 60)
|
| 252 |
+
print()
|
| 253 |
+
|
| 254 |
+
# 1. Load base model
|
| 255 |
+
arch = MINDIArchitecture(
|
| 256 |
+
model_name="Qwen/Qwen2.5-Coder-7B-Instruct",
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
# 2. Apply LoRA
|
| 260 |
+
peft_model = arch.apply_lora(
|
| 261 |
+
r=64,
|
| 262 |
+
lora_alpha=128,
|
| 263 |
+
lora_dropout=0.05,
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# 3. Print full info
|
| 267 |
+
arch.print_model_info()
|
| 268 |
+
|
| 269 |
+
# 4. Verify trainable params
|
| 270 |
+
info = arch.get_trainable_params()
|
| 271 |
+
assert info["trainable"] > 0, "No trainable parameters!"
|
| 272 |
+
assert info["frozen"] > info["trainable"], "More trainable than frozen β LoRA may not be applied!"
|
| 273 |
+
|
| 274 |
+
# 5. Verify LoRA modules exist
|
| 275 |
+
lora_modules = [name for name, _ in peft_model.named_parameters() if "lora_" in name]
|
| 276 |
+
print(f" LoRA modules found: {len(lora_modules)}")
|
| 277 |
+
assert len(lora_modules) > 0, "No LoRA modules found!"
|
| 278 |
+
|
| 279 |
+
# 6. Quick forward pass test (small input)
|
| 280 |
+
print("\n Running forward pass test ...")
|
| 281 |
+
test_input = arch.tokenizer("Hello MINDI!", return_tensors="pt")
|
| 282 |
+
test_input = {k: v.to(arch.device) for k, v in test_input.items()}
|
| 283 |
+
with torch.no_grad():
|
| 284 |
+
output = peft_model(**test_input)
|
| 285 |
+
print(f" Output logits shape: {output.logits.shape}")
|
| 286 |
+
print(f" Loss: {output.loss}")
|
| 287 |
+
|
| 288 |
+
print("\n β All architecture tests passed!")
|
| 289 |
+
print("=" * 60)
|
src/model/fusion_layer.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MINDI 1.5 Vision-Coder β Vision-Language Fusion Layer
|
| 3 |
+
|
| 4 |
+
Prepends projected visual tokens (256 Γ 4096) to text token embeddings
|
| 5 |
+
and extends the attention mask accordingly. Uses Linear + LayerNorm
|
| 6 |
+
for the visual projection gate.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from typing import Optional
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class VisionLanguageFusion(nn.Module):
|
| 18 |
+
"""
|
| 19 |
+
Fuses visual and text embeddings by prepending visual tokens.
|
| 20 |
+
|
| 21 |
+
Pipeline:
|
| 22 |
+
1. visual_tokens (batch, 256, 4096) β Linear β LayerNorm
|
| 23 |
+
2. Prepend to text_embeds (batch, seq_len, 4096)
|
| 24 |
+
3. Extend attention_mask to cover the extra 256 visual positions
|
| 25 |
+
|
| 26 |
+
All trainable parameters live in the gate projection + LayerNorm.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
hidden_size: int = 4096,
|
| 32 |
+
num_visual_tokens: int = 256,
|
| 33 |
+
) -> None:
|
| 34 |
+
"""
|
| 35 |
+
Initialize the fusion layer.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
hidden_size: Dimension of both visual and text embeddings (must match).
|
| 39 |
+
num_visual_tokens: Number of visual tokens prepended (default 256).
|
| 40 |
+
"""
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.hidden_size = hidden_size
|
| 43 |
+
self.num_visual_tokens = num_visual_tokens
|
| 44 |
+
|
| 45 |
+
# Gate projection: Linear + LayerNorm to align visual features
|
| 46 |
+
self.gate_proj = nn.Linear(hidden_size, hidden_size)
|
| 47 |
+
self.layer_norm = nn.LayerNorm(hidden_size)
|
| 48 |
+
|
| 49 |
+
def forward(
|
| 50 |
+
self,
|
| 51 |
+
text_embeds: torch.Tensor,
|
| 52 |
+
visual_tokens: Optional[torch.Tensor] = None,
|
| 53 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 54 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 55 |
+
"""
|
| 56 |
+
Fuse visual tokens into text embeddings.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
text_embeds: Text token embeddings (batch, seq_len, hidden_size).
|
| 60 |
+
visual_tokens: Projected visual tokens (batch, 256, hidden_size), or None
|
| 61 |
+
for text-only inputs.
|
| 62 |
+
attention_mask: Text attention mask (batch, seq_len), or None.
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
fused_embeds: (batch, 256 + seq_len, hidden_size) if visual, else unchanged.
|
| 66 |
+
fused_mask: Extended attention mask, or None if input mask was None.
|
| 67 |
+
"""
|
| 68 |
+
# Text-only path β no vision tokens to fuse
|
| 69 |
+
if visual_tokens is None:
|
| 70 |
+
return text_embeds, attention_mask
|
| 71 |
+
|
| 72 |
+
batch_size = text_embeds.shape[0]
|
| 73 |
+
v_batch = visual_tokens.shape[0]
|
| 74 |
+
|
| 75 |
+
# Handle batch size mismatch (single image broadcast to batch)
|
| 76 |
+
if v_batch == 1 and batch_size > 1:
|
| 77 |
+
visual_tokens = visual_tokens.expand(batch_size, -1, -1)
|
| 78 |
+
|
| 79 |
+
# Gate projection + LayerNorm
|
| 80 |
+
gated_visual = self.gate_proj(visual_tokens) # (batch, 256, hidden_size)
|
| 81 |
+
gated_visual = self.layer_norm(gated_visual) # (batch, 256, hidden_size)
|
| 82 |
+
|
| 83 |
+
# Prepend visual tokens to text embeddings
|
| 84 |
+
fused_embeds = torch.cat([gated_visual, text_embeds], dim=1)
|
| 85 |
+
|
| 86 |
+
# Extend attention mask
|
| 87 |
+
fused_mask = self._extend_attention_mask(attention_mask, batch_size, text_embeds.device)
|
| 88 |
+
|
| 89 |
+
return fused_embeds, fused_mask
|
| 90 |
+
|
| 91 |
+
def _extend_attention_mask(
|
| 92 |
+
self,
|
| 93 |
+
attention_mask: Optional[torch.Tensor],
|
| 94 |
+
batch_size: int,
|
| 95 |
+
device: torch.device,
|
| 96 |
+
) -> Optional[torch.Tensor]:
|
| 97 |
+
"""
|
| 98 |
+
Extend attention mask to include visual token positions (all attended).
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
attention_mask: Original text mask (batch, seq_len) or None.
|
| 102 |
+
batch_size: Current batch size.
|
| 103 |
+
device: Target device.
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
Extended mask (batch, 256 + seq_len) or None.
|
| 107 |
+
"""
|
| 108 |
+
if attention_mask is None:
|
| 109 |
+
return None
|
| 110 |
+
|
| 111 |
+
# Visual tokens are always fully attended
|
| 112 |
+
visual_mask = torch.ones(
|
| 113 |
+
batch_size,
|
| 114 |
+
self.num_visual_tokens,
|
| 115 |
+
dtype=attention_mask.dtype,
|
| 116 |
+
device=device,
|
| 117 |
+
)
|
| 118 |
+
return torch.cat([visual_mask, attention_mask], dim=1)
|
| 119 |
+
|
| 120 |
+
def get_trainable_params(self) -> dict:
|
| 121 |
+
"""
|
| 122 |
+
Count trainable parameters in the fusion layer.
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
Dictionary with 'trainable', 'total', and 'trainable_pct'.
|
| 126 |
+
"""
|
| 127 |
+
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 128 |
+
total = sum(p.numel() for p in self.parameters())
|
| 129 |
+
pct = 100.0 * trainable / total if total > 0 else 0.0
|
| 130 |
+
return {
|
| 131 |
+
"trainable": trainable,
|
| 132 |
+
"total": total,
|
| 133 |
+
"trainable_pct": round(pct, 4),
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
def extra_repr(self) -> str:
|
| 137 |
+
return (
|
| 138 |
+
f"hidden_size={self.hidden_size}, "
|
| 139 |
+
f"num_visual_tokens={self.num_visual_tokens}"
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# ββ Test block ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 144 |
+
if __name__ == "__main__":
|
| 145 |
+
print("=" * 60)
|
| 146 |
+
print(" MINDI 1.5 β Fusion Layer Test")
|
| 147 |
+
print("=" * 60)
|
| 148 |
+
print()
|
| 149 |
+
|
| 150 |
+
BATCH = 2
|
| 151 |
+
SEQ_LEN = 128
|
| 152 |
+
HIDDEN = 4096
|
| 153 |
+
N_VIS = 256
|
| 154 |
+
|
| 155 |
+
fusion = VisionLanguageFusion(hidden_size=HIDDEN, num_visual_tokens=N_VIS)
|
| 156 |
+
print(f" Fusion layer:\n {fusion}\n")
|
| 157 |
+
|
| 158 |
+
# ββ Test 1: Vision + Text fusion βββββββββββββββββββββββββββββ
|
| 159 |
+
print(" Test 1: Vision + Text fusion")
|
| 160 |
+
text_embeds = torch.randn(BATCH, SEQ_LEN, HIDDEN)
|
| 161 |
+
visual_tokens = torch.randn(BATCH, N_VIS, HIDDEN)
|
| 162 |
+
attention_mask = torch.ones(BATCH, SEQ_LEN, dtype=torch.long)
|
| 163 |
+
|
| 164 |
+
fused_embeds, fused_mask = fusion(text_embeds, visual_tokens, attention_mask)
|
| 165 |
+
|
| 166 |
+
expected_seq = N_VIS + SEQ_LEN # 256 + 128 = 384
|
| 167 |
+
assert fused_embeds.shape == (BATCH, expected_seq, HIDDEN), \
|
| 168 |
+
f"Expected ({BATCH}, {expected_seq}, {HIDDEN}), got {fused_embeds.shape}"
|
| 169 |
+
assert fused_mask is not None and fused_mask.shape == (BATCH, expected_seq), \
|
| 170 |
+
f"Expected mask ({BATCH}, {expected_seq}), got {fused_mask.shape}"
|
| 171 |
+
print(f" fused_embeds: {fused_embeds.shape} β")
|
| 172 |
+
print(f" fused_mask: {fused_mask.shape} β")
|
| 173 |
+
|
| 174 |
+
# ββ Test 2: Text-only (no vision) ββββββββββββββββββββββββββββ
|
| 175 |
+
print("\n Test 2: Text-only (no vision)")
|
| 176 |
+
text_only, mask_only = fusion(text_embeds, None, attention_mask)
|
| 177 |
+
assert text_only.shape == (BATCH, SEQ_LEN, HIDDEN)
|
| 178 |
+
assert mask_only is not None and mask_only.shape == (BATCH, SEQ_LEN)
|
| 179 |
+
print(f" text_only: {text_only.shape} β")
|
| 180 |
+
print(f" mask_only: {mask_only.shape} β")
|
| 181 |
+
|
| 182 |
+
# ββ Test 3: No attention mask ββββββββββββββββββββββββββββββββ
|
| 183 |
+
print("\n Test 3: Vision fusion without attention mask")
|
| 184 |
+
fused_no_mask, none_mask = fusion(text_embeds, visual_tokens, None)
|
| 185 |
+
assert fused_no_mask.shape == (BATCH, expected_seq, HIDDEN)
|
| 186 |
+
assert none_mask is None
|
| 187 |
+
print(f" fused_embeds: {fused_no_mask.shape} β")
|
| 188 |
+
print(f" fused_mask: None β")
|
| 189 |
+
|
| 190 |
+
# ββ Test 4: Single-image broadcast βββββββββββββββββββββββββββ
|
| 191 |
+
print("\n Test 4: Single-image broadcast to batch")
|
| 192 |
+
single_visual = torch.randn(1, N_VIS, HIDDEN)
|
| 193 |
+
fused_bc, mask_bc = fusion(text_embeds, single_visual, attention_mask)
|
| 194 |
+
assert fused_bc.shape == (BATCH, expected_seq, HIDDEN)
|
| 195 |
+
print(f" fused_embeds: {fused_bc.shape} β (broadcast 1 β {BATCH})")
|
| 196 |
+
|
| 197 |
+
# ββ Test 5: Trainable params βββββββββββββββββββββββββββββββββ
|
| 198 |
+
print("\n Test 5: Parameter counts")
|
| 199 |
+
info = fusion.get_trainable_params()
|
| 200 |
+
# gate_proj: 4096*4096 + 4096 = 16,781,312
|
| 201 |
+
# layer_norm: 4096 + 4096 = 8,192
|
| 202 |
+
expected_params = HIDDEN * HIDDEN + HIDDEN + HIDDEN + HIDDEN # Linear(w+b) + LN(w+b)
|
| 203 |
+
assert info["trainable"] == expected_params, \
|
| 204 |
+
f"Expected {expected_params}, got {info['trainable']}"
|
| 205 |
+
print(f" Trainable: {info['trainable']:,}")
|
| 206 |
+
print(f" Total: {info['total']:,}")
|
| 207 |
+
print(f" Pct: {info['trainable_pct']}%")
|
| 208 |
+
|
| 209 |
+
# ββ Test 6: Gradient flow ββββββββββββββββββββββββββββββββββββ
|
| 210 |
+
print("\n Test 6: Gradient flow through fusion")
|
| 211 |
+
fusion.zero_grad()
|
| 212 |
+
fused_embeds, _ = fusion(text_embeds, visual_tokens, attention_mask)
|
| 213 |
+
loss = fused_embeds.sum()
|
| 214 |
+
loss.backward()
|
| 215 |
+
assert fusion.gate_proj.weight.grad is not None, "No gradient on gate_proj!"
|
| 216 |
+
assert fusion.layer_norm.weight.grad is not None, "No gradient on layer_norm!"
|
| 217 |
+
print(" gate_proj gradient: β")
|
| 218 |
+
print(" layer_norm gradient: β")
|
| 219 |
+
|
| 220 |
+
print("\n β All fusion layer tests passed!")
|
| 221 |
+
print("=" * 60)
|
src/model/mindi_model.py
ADDED
|
@@ -0,0 +1,620 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MINDI 1.5 Vision-Coder β Complete Model
|
| 3 |
+
|
| 4 |
+
Combines MINDIArchitecture (Qwen2.5-Coder + LoRA), VisionEncoder (CLIP ViT-L/14),
|
| 5 |
+
and VisionLanguageFusion into a single MINDI15 class with forward(), generate(),
|
| 6 |
+
parse_output(), save(), and load() methods.
|
| 7 |
+
|
| 8 |
+
Uses the MINDI custom tokenizer (data/tokenizer/mindi_tokenizer/) with 22 special
|
| 9 |
+
tokens for agentic code generation capabilities.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import re
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Optional
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from PIL import Image
|
| 21 |
+
from transformers import AutoTokenizer, PreTrainedTokenizerFast
|
| 22 |
+
|
| 23 |
+
from src.model.architecture import MINDIArchitecture
|
| 24 |
+
from src.model.fusion_layer import VisionLanguageFusion
|
| 25 |
+
from src.model.vision_encoder import VisionEncoder
|
| 26 |
+
|
| 27 |
+
# ββ MINDI special token pairs ββββββββββββββββββββββββββββββββββββββββ
|
| 28 |
+
MINDI_SECTION_TOKENS: dict[str, tuple[str, str]] = {
|
| 29 |
+
"thinking": ("<|think_start|>", "<|think_end|>"),
|
| 30 |
+
"file": ("<|file_start|>", "<|file_end|>"),
|
| 31 |
+
"code": ("<|code_start|>", "<|code_end|>"),
|
| 32 |
+
"critique": ("<|critique_start|>", "<|critique_end|>"),
|
| 33 |
+
"suggest": ("<|suggest_start|>", "<|suggest_end|>"),
|
| 34 |
+
"search": ("<|search_start|>", "<|search_end|>"),
|
| 35 |
+
"error": ("<|error_start|>", "<|error_end|>"),
|
| 36 |
+
"fix": ("<|fix_start|>", "<|fix_end|>"),
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
# Project root (resolved relative to this file)
|
| 40 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
| 41 |
+
DEFAULT_TOKENIZER_PATH = PROJECT_ROOT / "data" / "tokenizer" / "mindi_tokenizer"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class MINDI15(nn.Module):
|
| 45 |
+
"""
|
| 46 |
+
MINDI 1.5 Vision-Coder β complete multimodal coding model.
|
| 47 |
+
|
| 48 |
+
Components:
|
| 49 |
+
- architecture: Qwen2.5-Coder-7B-Instruct + LoRA
|
| 50 |
+
- vision_encoder: CLIP ViT-L/14 (frozen) β 256 tokens Γ 4096
|
| 51 |
+
- fusion: Linear + LayerNorm prepend fusion
|
| 52 |
+
- tokenizer: MINDI custom tokenizer with 22 special tokens
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
model_name: str = "Qwen/Qwen2.5-Coder-7B-Instruct",
|
| 58 |
+
clip_model: str = "openai/clip-vit-large-patch14",
|
| 59 |
+
hidden_size: int = 4096,
|
| 60 |
+
num_visual_tokens: int = 256,
|
| 61 |
+
tokenizer_path: Optional[Path] = None,
|
| 62 |
+
device: Optional[str] = None,
|
| 63 |
+
torch_dtype: torch.dtype = torch.bfloat16,
|
| 64 |
+
cache_dir: Optional[Path] = None,
|
| 65 |
+
) -> None:
|
| 66 |
+
"""
|
| 67 |
+
Initialize MINDI 1.5 with all components.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
model_name: HuggingFace base LLM identifier.
|
| 71 |
+
clip_model: HuggingFace CLIP vision model identifier.
|
| 72 |
+
hidden_size: LLM hidden dimension (must match Qwen config).
|
| 73 |
+
num_visual_tokens: Number of visual tokens from CLIP (256).
|
| 74 |
+
tokenizer_path: Path to MINDI custom tokenizer directory.
|
| 75 |
+
device: Target device ('cuda', 'cpu', or None for auto).
|
| 76 |
+
torch_dtype: Data type for model weights.
|
| 77 |
+
cache_dir: Base directory for model weight caches.
|
| 78 |
+
"""
|
| 79 |
+
super().__init__()
|
| 80 |
+
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 81 |
+
self.hidden_size = hidden_size
|
| 82 |
+
self.num_visual_tokens = num_visual_tokens
|
| 83 |
+
self.torch_dtype = torch_dtype
|
| 84 |
+
|
| 85 |
+
cache_base = Path(cache_dir) if cache_dir else PROJECT_ROOT / "checkpoints"
|
| 86 |
+
|
| 87 |
+
print("=" * 60)
|
| 88 |
+
print(" MINDI 1.5 Vision-Coder β Initializing")
|
| 89 |
+
print("=" * 60)
|
| 90 |
+
|
| 91 |
+
# 1. Load MINDI custom tokenizer (NOT the base Qwen tokenizer)
|
| 92 |
+
tok_path = Path(tokenizer_path) if tokenizer_path else DEFAULT_TOKENIZER_PATH
|
| 93 |
+
print(f"\n[MINDI15] Loading MINDI tokenizer from {tok_path} ...")
|
| 94 |
+
self.tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(
|
| 95 |
+
str(tok_path),
|
| 96 |
+
trust_remote_code=True,
|
| 97 |
+
)
|
| 98 |
+
print(f" Vocab size: {len(self.tokenizer)}")
|
| 99 |
+
|
| 100 |
+
# 2. LLM backbone with LoRA
|
| 101 |
+
self.architecture = MINDIArchitecture(
|
| 102 |
+
model_name=model_name,
|
| 103 |
+
device=self.device,
|
| 104 |
+
cache_dir=cache_base / "base",
|
| 105 |
+
torch_dtype=torch_dtype,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Resize embeddings to match MINDI tokenizer (includes 22 special tokens)
|
| 109 |
+
self.architecture.resize_embeddings(len(self.tokenizer))
|
| 110 |
+
|
| 111 |
+
# Apply LoRA
|
| 112 |
+
self.architecture.apply_lora()
|
| 113 |
+
|
| 114 |
+
# 3. Vision encoder (frozen CLIP + trainable projection)
|
| 115 |
+
self.vision_encoder = VisionEncoder(
|
| 116 |
+
model_name=clip_model,
|
| 117 |
+
llm_hidden_size=hidden_size,
|
| 118 |
+
device=self.device,
|
| 119 |
+
cache_dir=cache_base / "vision",
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# 4. Fusion layer
|
| 123 |
+
self.fusion = VisionLanguageFusion(
|
| 124 |
+
hidden_size=hidden_size,
|
| 125 |
+
num_visual_tokens=num_visual_tokens,
|
| 126 |
+
)
|
| 127 |
+
self.fusion.to(self.device)
|
| 128 |
+
|
| 129 |
+
# Cache special token IDs
|
| 130 |
+
self._special_ids: dict[str, int] = {}
|
| 131 |
+
for section, (start_tok, end_tok) in MINDI_SECTION_TOKENS.items():
|
| 132 |
+
sid = self.tokenizer.convert_tokens_to_ids(start_tok)
|
| 133 |
+
eid = self.tokenizer.convert_tokens_to_ids(end_tok)
|
| 134 |
+
self._special_ids[f"{section}_start"] = sid
|
| 135 |
+
self._special_ids[f"{section}_end"] = eid
|
| 136 |
+
|
| 137 |
+
self._print_summary()
|
| 138 |
+
|
| 139 |
+
def _print_summary(self) -> None:
|
| 140 |
+
"""Print initialization summary."""
|
| 141 |
+
llm_info = self.architecture.get_trainable_params()
|
| 142 |
+
vis_info = {
|
| 143 |
+
"trainable": sum(p.numel() for p in self.vision_encoder.parameters() if p.requires_grad),
|
| 144 |
+
"total": sum(p.numel() for p in self.vision_encoder.parameters()),
|
| 145 |
+
}
|
| 146 |
+
fus_info = self.fusion.get_trainable_params()
|
| 147 |
+
|
| 148 |
+
total_trainable = llm_info["trainable"] + vis_info["trainable"] + fus_info["trainable"]
|
| 149 |
+
total_all = llm_info["total"] + vis_info["total"] + fus_info["total"]
|
| 150 |
+
|
| 151 |
+
print()
|
| 152 |
+
print("=" * 60)
|
| 153 |
+
print(" MINDI 1.5 β Initialization Complete")
|
| 154 |
+
print("=" * 60)
|
| 155 |
+
print(f" LLM trainable (LoRA): {llm_info['trainable']:>14,}")
|
| 156 |
+
print(f" Vision trainable: {vis_info['trainable']:>14,}")
|
| 157 |
+
print(f" Fusion trainable: {fus_info['trainable']:>14,}")
|
| 158 |
+
print(f" βββββββββββββββββββββββββββββββββββββ")
|
| 159 |
+
print(f" Total trainable: {total_trainable:>14,}")
|
| 160 |
+
print(f" Total params: {total_all:>14,}")
|
| 161 |
+
print(f" Tokenizer vocab: {len(self.tokenizer):>14,}")
|
| 162 |
+
print("=" * 60)
|
| 163 |
+
print()
|
| 164 |
+
|
| 165 |
+
# ββ Forward pass ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 166 |
+
|
| 167 |
+
def forward(
|
| 168 |
+
self,
|
| 169 |
+
input_ids: torch.Tensor,
|
| 170 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 171 |
+
labels: Optional[torch.Tensor] = None,
|
| 172 |
+
image: Optional[Image.Image] = None,
|
| 173 |
+
) -> dict:
|
| 174 |
+
"""
|
| 175 |
+
Forward pass with optional vision input.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
input_ids: Token IDs (batch, seq_len).
|
| 179 |
+
attention_mask: Attention mask (batch, seq_len).
|
| 180 |
+
labels: Target token IDs for loss computation (batch, seq_len).
|
| 181 |
+
image: Optional PIL image for multimodal input.
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
Dict with 'loss', 'logits', and optionally 'visual_tokens'.
|
| 185 |
+
"""
|
| 186 |
+
model = self.architecture.get_model()
|
| 187 |
+
|
| 188 |
+
# Get text embeddings from the LLM's embedding layer
|
| 189 |
+
text_embeds = model.get_input_embeddings()(input_ids)
|
| 190 |
+
|
| 191 |
+
# Encode vision if image provided
|
| 192 |
+
visual_tokens = None
|
| 193 |
+
if image is not None:
|
| 194 |
+
visual_tokens = self.vision_encoder.encode_image(image)
|
| 195 |
+
|
| 196 |
+
# Fuse vision + text
|
| 197 |
+
fused_embeds, fused_mask = self.fusion(text_embeds, visual_tokens, attention_mask)
|
| 198 |
+
|
| 199 |
+
# Extend labels if vision tokens were prepended
|
| 200 |
+
if visual_tokens is not None and labels is not None:
|
| 201 |
+
batch_size = labels.shape[0]
|
| 202 |
+
# -100 = ignore index for cross-entropy on visual positions
|
| 203 |
+
visual_labels = torch.full(
|
| 204 |
+
(batch_size, self.num_visual_tokens),
|
| 205 |
+
fill_value=-100,
|
| 206 |
+
dtype=labels.dtype,
|
| 207 |
+
device=labels.device,
|
| 208 |
+
)
|
| 209 |
+
labels = torch.cat([visual_labels, labels], dim=1)
|
| 210 |
+
|
| 211 |
+
# Forward through LLM with embeddings (bypass tokenization)
|
| 212 |
+
outputs = model(
|
| 213 |
+
inputs_embeds=fused_embeds,
|
| 214 |
+
attention_mask=fused_mask,
|
| 215 |
+
labels=labels,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
result = {
|
| 219 |
+
"loss": outputs.loss,
|
| 220 |
+
"logits": outputs.logits,
|
| 221 |
+
}
|
| 222 |
+
if visual_tokens is not None:
|
| 223 |
+
result["visual_tokens"] = visual_tokens
|
| 224 |
+
|
| 225 |
+
return result
|
| 226 |
+
|
| 227 |
+
# ββ Generation ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 228 |
+
|
| 229 |
+
@torch.no_grad()
|
| 230 |
+
def generate(
|
| 231 |
+
self,
|
| 232 |
+
prompt: str,
|
| 233 |
+
image: Optional[Image.Image] = None,
|
| 234 |
+
max_new_tokens: int = 2048,
|
| 235 |
+
temperature: float = 0.7,
|
| 236 |
+
top_p: float = 0.9,
|
| 237 |
+
top_k: int = 50,
|
| 238 |
+
do_sample: bool = True,
|
| 239 |
+
repetition_penalty: float = 1.1,
|
| 240 |
+
) -> str:
|
| 241 |
+
"""
|
| 242 |
+
Generate text from a prompt, optionally conditioned on an image.
|
| 243 |
+
|
| 244 |
+
Uses the MINDI custom tokenizer (with special tokens) for both
|
| 245 |
+
encoding the prompt and decoding the output.
|
| 246 |
+
|
| 247 |
+
Args:
|
| 248 |
+
prompt: Input text prompt.
|
| 249 |
+
image: Optional PIL image for multimodal generation.
|
| 250 |
+
max_new_tokens: Maximum tokens to generate.
|
| 251 |
+
temperature: Sampling temperature.
|
| 252 |
+
top_p: Nucleus sampling threshold.
|
| 253 |
+
top_k: Top-k sampling threshold.
|
| 254 |
+
do_sample: Whether to sample (False = greedy).
|
| 255 |
+
repetition_penalty: Penalty for repeated tokens.
|
| 256 |
+
|
| 257 |
+
Returns:
|
| 258 |
+
Generated text string (decoded with MINDI tokenizer).
|
| 259 |
+
"""
|
| 260 |
+
model = self.architecture.get_model()
|
| 261 |
+
model.eval()
|
| 262 |
+
|
| 263 |
+
# Tokenize with MINDI tokenizer
|
| 264 |
+
inputs = self.tokenizer(prompt, return_tensors="pt")
|
| 265 |
+
input_ids = inputs["input_ids"].to(self.device)
|
| 266 |
+
attention_mask = inputs["attention_mask"].to(self.device)
|
| 267 |
+
|
| 268 |
+
# If image provided, build fused embeddings
|
| 269 |
+
if image is not None:
|
| 270 |
+
text_embeds = model.get_input_embeddings()(input_ids)
|
| 271 |
+
visual_tokens = self.vision_encoder.encode_image(image)
|
| 272 |
+
fused_embeds, fused_mask = self.fusion(text_embeds, visual_tokens, attention_mask)
|
| 273 |
+
|
| 274 |
+
output_ids = model.generate(
|
| 275 |
+
inputs_embeds=fused_embeds,
|
| 276 |
+
attention_mask=fused_mask,
|
| 277 |
+
max_new_tokens=max_new_tokens,
|
| 278 |
+
temperature=temperature,
|
| 279 |
+
top_p=top_p,
|
| 280 |
+
top_k=top_k,
|
| 281 |
+
do_sample=do_sample,
|
| 282 |
+
repetition_penalty=repetition_penalty,
|
| 283 |
+
pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
|
| 284 |
+
)
|
| 285 |
+
else:
|
| 286 |
+
# Text-only generation (direct input_ids)
|
| 287 |
+
output_ids = model.generate(
|
| 288 |
+
input_ids=input_ids,
|
| 289 |
+
attention_mask=attention_mask,
|
| 290 |
+
max_new_tokens=max_new_tokens,
|
| 291 |
+
temperature=temperature,
|
| 292 |
+
top_p=top_p,
|
| 293 |
+
top_k=top_k,
|
| 294 |
+
do_sample=do_sample,
|
| 295 |
+
repetition_penalty=repetition_penalty,
|
| 296 |
+
pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
# Decode only the newly generated tokens
|
| 300 |
+
generated_ids = output_ids[:, input_ids.shape[1]:]
|
| 301 |
+
text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=False)
|
| 302 |
+
return text.strip()
|
| 303 |
+
|
| 304 |
+
# ββ Output parsing ββββββββββββββββββββββββββββββββββββββββββββ
|
| 305 |
+
|
| 306 |
+
@staticmethod
|
| 307 |
+
def parse_output(text: str) -> dict[str, list[str]]:
|
| 308 |
+
"""
|
| 309 |
+
Parse generated text and extract ALL MINDI special-token sections.
|
| 310 |
+
|
| 311 |
+
Extracts content between each pair of special tokens:
|
| 312 |
+
<|think_start|> ... <|think_end|> β "thinking"
|
| 313 |
+
<|file_start|> ... <|file_end|> β "file"
|
| 314 |
+
<|code_start|> ... <|code_end|> β "code"
|
| 315 |
+
<|critique_start|> ... <|critique_end|> β "critique"
|
| 316 |
+
<|suggest_start|> ... <|suggest_end|> β "suggest"
|
| 317 |
+
<|search_start|> ... <|search_end|> β "search"
|
| 318 |
+
<|error_start|> ... <|error_end|> β "error"
|
| 319 |
+
<|fix_start|> ... <|fix_end|> β "fix"
|
| 320 |
+
|
| 321 |
+
Each section may appear multiple times; all occurrences are captured.
|
| 322 |
+
|
| 323 |
+
Args:
|
| 324 |
+
text: Raw generated text potentially containing special tokens.
|
| 325 |
+
|
| 326 |
+
Returns:
|
| 327 |
+
Dict mapping section name β list of extracted content strings.
|
| 328 |
+
Empty list if section not found. Also includes "raw" with full text.
|
| 329 |
+
"""
|
| 330 |
+
result: dict[str, list[str]] = {"raw": [text]}
|
| 331 |
+
|
| 332 |
+
for section, (start_tok, end_tok) in MINDI_SECTION_TOKENS.items():
|
| 333 |
+
# Escape the pipe characters for regex
|
| 334 |
+
pattern = re.escape(start_tok) + r"(.*?)" + re.escape(end_tok)
|
| 335 |
+
matches = re.findall(pattern, text, flags=re.DOTALL)
|
| 336 |
+
result[section] = [m.strip() for m in matches]
|
| 337 |
+
|
| 338 |
+
return result
|
| 339 |
+
|
| 340 |
+
# ββ Phase control (for 3-phase training) ββββββββββββββββββββββ
|
| 341 |
+
|
| 342 |
+
def set_trainable_components(
|
| 343 |
+
self,
|
| 344 |
+
lora: bool = False,
|
| 345 |
+
vision_projection: bool = False,
|
| 346 |
+
fusion: bool = False,
|
| 347 |
+
) -> dict[str, int]:
|
| 348 |
+
"""
|
| 349 |
+
Enable/disable training for specific components.
|
| 350 |
+
|
| 351 |
+
Used by the trainer to implement 3-phase training:
|
| 352 |
+
Phase 1: lora=True, vision_projection=False, fusion=False
|
| 353 |
+
Phase 2: lora=False, vision_projection=True, fusion=True
|
| 354 |
+
Phase 3: lora=True, vision_projection=True, fusion=True
|
| 355 |
+
|
| 356 |
+
Args:
|
| 357 |
+
lora: Whether LoRA adapter parameters should be trainable.
|
| 358 |
+
vision_projection: Whether the vision projection layer should train.
|
| 359 |
+
fusion: Whether the fusion layer should be trainable.
|
| 360 |
+
|
| 361 |
+
Returns:
|
| 362 |
+
Dict with trainable param counts per component.
|
| 363 |
+
"""
|
| 364 |
+
counts = {}
|
| 365 |
+
|
| 366 |
+
# LoRA parameters
|
| 367 |
+
peft_model = self.architecture.peft_model
|
| 368 |
+
if peft_model is not None:
|
| 369 |
+
for name, param in peft_model.named_parameters():
|
| 370 |
+
if "lora_" in name:
|
| 371 |
+
param.requires_grad = lora
|
| 372 |
+
counts["lora"] = sum(
|
| 373 |
+
p.numel() for n, p in (peft_model or self.architecture.model).named_parameters()
|
| 374 |
+
if "lora_" in n and p.requires_grad
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
# Vision projection
|
| 378 |
+
for param in self.vision_encoder.projection.parameters():
|
| 379 |
+
param.requires_grad = vision_projection
|
| 380 |
+
counts["vision_projection"] = sum(
|
| 381 |
+
p.numel() for p in self.vision_encoder.projection.parameters() if p.requires_grad
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
# Fusion layer
|
| 385 |
+
for param in self.fusion.parameters():
|
| 386 |
+
param.requires_grad = fusion
|
| 387 |
+
counts["fusion"] = sum(
|
| 388 |
+
p.numel() for p in self.fusion.parameters() if p.requires_grad
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
counts["total_trainable"] = counts["lora"] + counts["vision_projection"] + counts["fusion"]
|
| 392 |
+
|
| 393 |
+
print(f"[MINDI15] Trainable: LoRA={counts['lora']:,} | "
|
| 394 |
+
f"VisionProj={counts['vision_projection']:,} | "
|
| 395 |
+
f"Fusion={counts['fusion']:,} | "
|
| 396 |
+
f"Total={counts['total_trainable']:,}")
|
| 397 |
+
|
| 398 |
+
return counts
|
| 399 |
+
|
| 400 |
+
# ββ Save / Load βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 401 |
+
|
| 402 |
+
def save(self, save_dir: Optional[Path] = None) -> Path:
|
| 403 |
+
"""
|
| 404 |
+
Save all trainable weights (LoRA + vision projection + fusion).
|
| 405 |
+
|
| 406 |
+
Args:
|
| 407 |
+
save_dir: Root directory for saving. Defaults to checkpoints/mindi15.
|
| 408 |
+
|
| 409 |
+
Returns:
|
| 410 |
+
Path to save directory.
|
| 411 |
+
"""
|
| 412 |
+
save_path = Path(save_dir) if save_dir else PROJECT_ROOT / "checkpoints" / "mindi15"
|
| 413 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
| 414 |
+
|
| 415 |
+
# LoRA adapter
|
| 416 |
+
self.architecture.save_lora(save_path / "lora")
|
| 417 |
+
|
| 418 |
+
# Vision projection
|
| 419 |
+
self.vision_encoder.save_projection(save_path / "vision")
|
| 420 |
+
|
| 421 |
+
# Fusion layer
|
| 422 |
+
fusion_path = save_path / "fusion"
|
| 423 |
+
fusion_path.mkdir(parents=True, exist_ok=True)
|
| 424 |
+
torch.save(self.fusion.state_dict(), fusion_path / "fusion.pt")
|
| 425 |
+
|
| 426 |
+
print(f"[MINDI15] All weights saved to {save_path}")
|
| 427 |
+
return save_path
|
| 428 |
+
|
| 429 |
+
def load(self, load_dir: Path) -> None:
|
| 430 |
+
"""
|
| 431 |
+
Load all trainable weights (LoRA + vision projection + fusion).
|
| 432 |
+
|
| 433 |
+
Args:
|
| 434 |
+
load_dir: Root directory containing saved weights.
|
| 435 |
+
"""
|
| 436 |
+
load_path = Path(load_dir)
|
| 437 |
+
if not load_path.exists():
|
| 438 |
+
raise FileNotFoundError(f"Checkpoint not found: {load_path}")
|
| 439 |
+
|
| 440 |
+
# LoRA adapter
|
| 441 |
+
lora_path = load_path / "lora"
|
| 442 |
+
if lora_path.exists():
|
| 443 |
+
self.architecture.load_lora(lora_path)
|
| 444 |
+
|
| 445 |
+
# Vision projection
|
| 446 |
+
vision_path = load_path / "vision"
|
| 447 |
+
if vision_path.exists():
|
| 448 |
+
self.vision_encoder.load_projection(vision_path)
|
| 449 |
+
|
| 450 |
+
# Fusion layer
|
| 451 |
+
fusion_file = load_path / "fusion" / "fusion.pt"
|
| 452 |
+
if fusion_file.exists():
|
| 453 |
+
state_dict = torch.load(fusion_file, map_location=self.device, weights_only=True)
|
| 454 |
+
self.fusion.load_state_dict(state_dict)
|
| 455 |
+
print(f"[MINDI15] Fusion loaded from {fusion_file.parent}")
|
| 456 |
+
|
| 457 |
+
print(f"[MINDI15] All weights loaded from {load_path}")
|
| 458 |
+
|
| 459 |
+
# ββ Utilities βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 460 |
+
|
| 461 |
+
def get_all_trainable_params(self) -> dict:
|
| 462 |
+
"""Get combined trainable parameter counts across all components."""
|
| 463 |
+
llm = self.architecture.get_trainable_params()
|
| 464 |
+
vis_trainable = sum(
|
| 465 |
+
p.numel() for p in self.vision_encoder.parameters() if p.requires_grad
|
| 466 |
+
)
|
| 467 |
+
fus = self.fusion.get_trainable_params()
|
| 468 |
+
|
| 469 |
+
total_trainable = llm["trainable"] + vis_trainable + fus["trainable"]
|
| 470 |
+
total_all = llm["total"] + sum(p.numel() for p in self.vision_encoder.parameters()) + fus["total"]
|
| 471 |
+
|
| 472 |
+
return {
|
| 473 |
+
"llm_trainable": llm["trainable"],
|
| 474 |
+
"llm_total": llm["total"],
|
| 475 |
+
"vision_trainable": vis_trainable,
|
| 476 |
+
"fusion_trainable": fus["trainable"],
|
| 477 |
+
"total_trainable": total_trainable,
|
| 478 |
+
"total_params": total_all,
|
| 479 |
+
"trainable_pct": round(100.0 * total_trainable / total_all, 4) if total_all > 0 else 0.0,
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
def print_info(self) -> None:
|
| 483 |
+
"""Print complete model information."""
|
| 484 |
+
self.architecture.print_model_info()
|
| 485 |
+
info = self.get_all_trainable_params()
|
| 486 |
+
print(" MINDI 1.5 Combined Trainable Parameters:")
|
| 487 |
+
print(f" LLM (LoRA): {info['llm_trainable']:>14,}")
|
| 488 |
+
print(f" Vision proj: {info['vision_trainable']:>14,}")
|
| 489 |
+
print(f" Fusion: {info['fusion_trainable']:>14,}")
|
| 490 |
+
print(f" Total trainable: {info['total_trainable']:>14,}")
|
| 491 |
+
print(f" Total params: {info['total_params']:>14,}")
|
| 492 |
+
print(f" Trainable %: {info['trainable_pct']:>13.2f}%")
|
| 493 |
+
print()
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
# ββ Test block ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 497 |
+
if __name__ == "__main__":
|
| 498 |
+
print("=" * 60)
|
| 499 |
+
print(" MINDI 1.5 β Complete Model Test")
|
| 500 |
+
print("=" * 60)
|
| 501 |
+
print()
|
| 502 |
+
|
| 503 |
+
# ββ Test 1: parse_output (no GPU needed) βββββββββββββββββββββ
|
| 504 |
+
print(" Test 1: parse_output()")
|
| 505 |
+
sample_output = (
|
| 506 |
+
"<|think_start|>The user wants a Python function.<|think_end|>"
|
| 507 |
+
"<|file_start|>main.py<|file_end|>"
|
| 508 |
+
"<|code_start|>def hello():\n print('Hello MINDI!')<|code_end|>"
|
| 509 |
+
"<|critique_start|>Missing type hints and docstring.<|critique_end|>"
|
| 510 |
+
"<|suggest_start|>Add return type annotation.<|suggest_end|>"
|
| 511 |
+
"<|search_start|>python type hints best practices<|search_end|>"
|
| 512 |
+
"<|error_start|>NameError: name 'x' is not defined<|error_end|>"
|
| 513 |
+
"<|fix_start|>Add x = 0 before the loop.<|fix_end|>"
|
| 514 |
+
"<|think_start|>Let me also add error handling.<|think_end|>"
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
parsed = MINDI15.parse_output(sample_output)
|
| 518 |
+
|
| 519 |
+
assert len(parsed["thinking"]) == 2, f"Expected 2 thinking sections, got {len(parsed['thinking'])}"
|
| 520 |
+
assert parsed["thinking"][0] == "The user wants a Python function."
|
| 521 |
+
assert parsed["thinking"][1] == "Let me also add error handling."
|
| 522 |
+
assert parsed["file"] == ["main.py"]
|
| 523 |
+
assert parsed["code"] == ["def hello():\n print('Hello MINDI!')"]
|
| 524 |
+
assert parsed["critique"] == ["Missing type hints and docstring."]
|
| 525 |
+
assert parsed["suggest"] == ["Add return type annotation."]
|
| 526 |
+
assert parsed["search"] == ["python type hints best practices"]
|
| 527 |
+
assert parsed["error"] == ["NameError: name 'x' is not defined"]
|
| 528 |
+
assert parsed["fix"] == ["Add x = 0 before the loop."]
|
| 529 |
+
assert "raw" in parsed
|
| 530 |
+
print(" All 8 section types extracted correctly β")
|
| 531 |
+
print(f" Sections found: {[k for k, v in parsed.items() if k != 'raw' and v]}")
|
| 532 |
+
|
| 533 |
+
# ββ Test 2: parse_output with missing sections βββββββββββββββ
|
| 534 |
+
print("\n Test 2: parse_output() with partial output")
|
| 535 |
+
partial = "<|code_start|>print('hi')<|code_end|>"
|
| 536 |
+
parsed2 = MINDI15.parse_output(partial)
|
| 537 |
+
assert parsed2["code"] == ["print('hi')"]
|
| 538 |
+
assert parsed2["thinking"] == []
|
| 539 |
+
assert parsed2["file"] == []
|
| 540 |
+
assert parsed2["fix"] == []
|
| 541 |
+
print(" Missing sections return empty lists β")
|
| 542 |
+
|
| 543 |
+
# ββ Test 3: parse_output with empty input ββββββββββββββββββββ
|
| 544 |
+
print("\n Test 3: parse_output() with empty string")
|
| 545 |
+
parsed3 = MINDI15.parse_output("")
|
| 546 |
+
assert all(v == [] for k, v in parsed3.items() if k != "raw")
|
| 547 |
+
print(" Empty input returns all empty lists β")
|
| 548 |
+
|
| 549 |
+
# ββ Test 4: Verify MINDI_SECTION_TOKENS covers all 8 ββββββββ
|
| 550 |
+
print("\n Test 4: Token coverage")
|
| 551 |
+
expected_sections = {"thinking", "file", "code", "critique", "suggest", "search", "error", "fix"}
|
| 552 |
+
assert set(MINDI_SECTION_TOKENS.keys()) == expected_sections
|
| 553 |
+
print(f" All 8 sections defined: {sorted(expected_sections)} β")
|
| 554 |
+
|
| 555 |
+
# ββ GPU-dependent tests (skip if no CUDA) ββββββββββββββββββββ
|
| 556 |
+
if torch.cuda.is_available():
|
| 557 |
+
print("\n Test 5: Full model initialization (GPU)")
|
| 558 |
+
model = MINDI15()
|
| 559 |
+
model.print_info()
|
| 560 |
+
|
| 561 |
+
# Test set_trainable_components (Phase 1)
|
| 562 |
+
print("\n Test 6: Phase 1 β LoRA only")
|
| 563 |
+
counts = model.set_trainable_components(lora=True, vision_projection=False, fusion=False)
|
| 564 |
+
assert counts["lora"] > 0
|
| 565 |
+
assert counts["vision_projection"] == 0
|
| 566 |
+
assert counts["fusion"] == 0
|
| 567 |
+
|
| 568 |
+
# Test set_trainable_components (Phase 2)
|
| 569 |
+
print("\n Test 7: Phase 2 β Vision bridge only")
|
| 570 |
+
counts = model.set_trainable_components(lora=False, vision_projection=True, fusion=True)
|
| 571 |
+
assert counts["lora"] == 0
|
| 572 |
+
assert counts["vision_projection"] > 0
|
| 573 |
+
assert counts["fusion"] > 0
|
| 574 |
+
|
| 575 |
+
# Test set_trainable_components (Phase 3)
|
| 576 |
+
print("\n Test 8: Phase 3 β All trainable")
|
| 577 |
+
counts = model.set_trainable_components(lora=True, vision_projection=True, fusion=True)
|
| 578 |
+
assert counts["lora"] > 0
|
| 579 |
+
assert counts["vision_projection"] > 0
|
| 580 |
+
assert counts["fusion"] > 0
|
| 581 |
+
|
| 582 |
+
# Test forward (text only)
|
| 583 |
+
print("\n Test 9: Forward pass (text only)")
|
| 584 |
+
tokens = model.tokenizer("Hello MINDI!", return_tensors="pt")
|
| 585 |
+
input_ids = tokens["input_ids"].to(model.device)
|
| 586 |
+
attn_mask = tokens["attention_mask"].to(model.device)
|
| 587 |
+
result = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids)
|
| 588 |
+
assert result["loss"] is not None
|
| 589 |
+
print(f" Loss: {result['loss'].item():.4f}")
|
| 590 |
+
print(f" Logits: {result['logits'].shape}")
|
| 591 |
+
|
| 592 |
+
# Test forward (with image)
|
| 593 |
+
print("\n Test 10: Forward pass (with dummy image)")
|
| 594 |
+
dummy_img = Image.new("RGB", (224, 224), color=(100, 150, 200))
|
| 595 |
+
result_v = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, image=dummy_img)
|
| 596 |
+
assert result_v["loss"] is not None
|
| 597 |
+
assert "visual_tokens" in result_v
|
| 598 |
+
print(f" Loss: {result_v['loss'].item():.4f}")
|
| 599 |
+
print(f" Visual tokens: {result_v['visual_tokens'].shape}")
|
| 600 |
+
|
| 601 |
+
# Test generate (text only)
|
| 602 |
+
print("\n Test 11: Generate (text only, short)")
|
| 603 |
+
output = model.generate("Write a hello world in Python:", max_new_tokens=50)
|
| 604 |
+
print(f" Output: {output[:100]}...")
|
| 605 |
+
|
| 606 |
+
print("\n Test 12: Save/load round-trip")
|
| 607 |
+
import tempfile
|
| 608 |
+
with tempfile.TemporaryDirectory() as tmp:
|
| 609 |
+
model.save(Path(tmp))
|
| 610 |
+
# Verify files exist
|
| 611 |
+
assert (Path(tmp) / "lora").exists()
|
| 612 |
+
assert (Path(tmp) / "vision" / "projection.pt").exists()
|
| 613 |
+
assert (Path(tmp) / "fusion" / "fusion.pt").exists()
|
| 614 |
+
print(" Save β")
|
| 615 |
+
else:
|
| 616 |
+
print("\n [SKIP] GPU tests (no CUDA available)")
|
| 617 |
+
print(" Tests 5-12 require GPU with ~20GB VRAM")
|
| 618 |
+
|
| 619 |
+
print("\n β All MINDI 1.5 model tests passed!")
|
| 620 |
+
print("=" * 60)
|
src/model/vision_encoder.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
"""
|
| 2 |
MINDI 1.5 Vision-Coder β Vision Encoder
|
| 3 |
|
| 4 |
-
Uses CLIP ViT-L/14 to encode UI screenshots into
|
| 5 |
-
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
from __future__ import annotations
|
|
@@ -13,79 +14,237 @@ from typing import Optional
|
|
| 13 |
import torch
|
| 14 |
import torch.nn as nn
|
| 15 |
from PIL import Image
|
| 16 |
-
from transformers import
|
| 17 |
|
| 18 |
|
| 19 |
class VisionEncoder(nn.Module):
|
| 20 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
def __init__(
|
| 23 |
self,
|
| 24 |
model_name: str = "openai/clip-vit-large-patch14",
|
| 25 |
-
|
| 26 |
device: Optional[str] = None,
|
| 27 |
cache_dir: Optional[Path] = None,
|
|
|
|
| 28 |
) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
super().__init__()
|
|
|
|
|
|
|
| 30 |
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 31 |
-
self.cache_dir = cache_dir
|
| 32 |
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
| 33 |
|
| 34 |
-
# Load CLIP model and processor
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
| 37 |
)
|
| 38 |
-
self.
|
| 39 |
-
model_name,
|
|
|
|
| 40 |
)
|
| 41 |
|
| 42 |
-
# Freeze CLIP backbone
|
| 43 |
for param in self.clip.parameters():
|
| 44 |
param.requires_grad = False
|
|
|
|
| 45 |
|
| 46 |
-
# Trainable projection: CLIP hidden β LLM
|
| 47 |
-
clip_hidden_size: int = self.clip.config.
|
| 48 |
-
self.projection = nn.
|
| 49 |
-
nn.Linear(clip_hidden_size, projection_dim),
|
| 50 |
-
nn.GELU(),
|
| 51 |
-
nn.Linear(projection_dim, projection_dim),
|
| 52 |
-
)
|
| 53 |
|
| 54 |
self.to(self.device)
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
with torch.no_grad():
|
| 62 |
-
vision_outputs = self.clip
|
| 63 |
-
#
|
| 64 |
-
|
| 65 |
|
| 66 |
-
# Project into LLM embedding space (
|
| 67 |
-
projected = self.projection(
|
| 68 |
return projected
|
| 69 |
|
| 70 |
-
def
|
| 71 |
-
"""
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
return self.encode_image(image)
|
| 77 |
|
| 78 |
def save_projection(self, save_dir: Optional[Path] = None) -> Path:
|
| 79 |
-
"""
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
save_path.mkdir(parents=True, exist_ok=True)
|
| 82 |
torch.save(self.projection.state_dict(), save_path / "projection.pt")
|
|
|
|
| 83 |
return save_path
|
| 84 |
|
| 85 |
def load_projection(self, load_dir: Path) -> None:
|
| 86 |
-
"""
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
if not weights_path.exists():
|
| 89 |
raise FileNotFoundError(f"Projection weights not found: {weights_path}")
|
| 90 |
state_dict = torch.load(weights_path, map_location=self.device, weights_only=True)
|
| 91 |
self.projection.load_state_dict(state_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
MINDI 1.5 Vision-Coder β Vision Encoder
|
| 3 |
|
| 4 |
+
Uses CLIP ViT-L/14 (frozen) to encode UI screenshots into 256 visual
|
| 5 |
+
tokens projected from 1024 β 4096 to match the Qwen hidden dimension.
|
| 6 |
+
Output shape: (batch, 256, 4096).
|
| 7 |
"""
|
| 8 |
|
| 9 |
from __future__ import annotations
|
|
|
|
| 14 |
import torch
|
| 15 |
import torch.nn as nn
|
| 16 |
from PIL import Image
|
| 17 |
+
from transformers import CLIPImageProcessor, CLIPVisionModel
|
| 18 |
|
| 19 |
|
| 20 |
class VisionEncoder(nn.Module):
|
| 21 |
+
"""
|
| 22 |
+
CLIP ViT-L/14 vision encoder for MINDI 1.5.
|
| 23 |
+
|
| 24 |
+
Extracts ALL 256 patch tokens (excludes CLS) from CLIP and
|
| 25 |
+
projects them from 1024 β 4096 to match Qwen2.5 hidden_size.
|
| 26 |
+
The CLIP backbone is frozen; only the projection layer trains.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
NUM_PATCHES: int = 256 # ViT-L/14: 16Γ16 patches from 224Γ224
|
| 30 |
|
| 31 |
def __init__(
|
| 32 |
self,
|
| 33 |
model_name: str = "openai/clip-vit-large-patch14",
|
| 34 |
+
llm_hidden_size: int = 4096,
|
| 35 |
device: Optional[str] = None,
|
| 36 |
cache_dir: Optional[Path] = None,
|
| 37 |
+
torch_dtype: torch.dtype = torch.float32,
|
| 38 |
) -> None:
|
| 39 |
+
"""
|
| 40 |
+
Initialize the vision encoder.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
model_name: HuggingFace CLIP vision model identifier.
|
| 44 |
+
llm_hidden_size: Target projection dimension (must match LLM hidden_size).
|
| 45 |
+
device: Target device ('cuda', 'cpu', or None for auto).
|
| 46 |
+
cache_dir: Local directory for model weight cache.
|
| 47 |
+
torch_dtype: Data type for CLIP weights (projection always float32).
|
| 48 |
+
"""
|
| 49 |
super().__init__()
|
| 50 |
+
self.model_name = model_name
|
| 51 |
+
self.llm_hidden_size = llm_hidden_size
|
| 52 |
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 53 |
+
self.cache_dir = Path(cache_dir) if cache_dir else Path("./checkpoints/vision")
|
| 54 |
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
| 55 |
|
| 56 |
+
# Load CLIP vision model (no text tower) and image processor
|
| 57 |
+
print(f"[VisionEncoder] Loading {model_name} ...")
|
| 58 |
+
self.clip = CLIPVisionModel.from_pretrained(
|
| 59 |
+
model_name,
|
| 60 |
+
cache_dir=str(self.cache_dir),
|
| 61 |
+
torch_dtype=torch_dtype,
|
| 62 |
)
|
| 63 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(
|
| 64 |
+
model_name,
|
| 65 |
+
cache_dir=str(self.cache_dir),
|
| 66 |
)
|
| 67 |
|
| 68 |
+
# Freeze entire CLIP backbone
|
| 69 |
for param in self.clip.parameters():
|
| 70 |
param.requires_grad = False
|
| 71 |
+
self.clip.eval()
|
| 72 |
|
| 73 |
+
# Trainable projection: CLIP hidden (1024) β LLM hidden (4096)
|
| 74 |
+
clip_hidden_size: int = self.clip.config.hidden_size # 1024
|
| 75 |
+
self.projection = nn.Linear(clip_hidden_size, self.llm_hidden_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
self.to(self.device)
|
| 78 |
|
| 79 |
+
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 80 |
+
total = sum(p.numel() for p in self.parameters())
|
| 81 |
+
print(f"[VisionEncoder] Loaded β {clip_hidden_size} β {self.llm_hidden_size}")
|
| 82 |
+
print(f" Trainable: {trainable:,} | Total: {total:,}")
|
| 83 |
+
|
| 84 |
+
def encode_image(self, image: Optional[Image.Image]) -> Optional[torch.Tensor]:
|
| 85 |
+
"""
|
| 86 |
+
Encode a single PIL image into projected patch token embeddings.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
image: A PIL Image (RGB), or None.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
Tensor of shape (1, 256, 4096) or None if input is None.
|
| 93 |
+
"""
|
| 94 |
+
if image is None:
|
| 95 |
+
return None
|
| 96 |
+
|
| 97 |
+
inputs = self.image_processor(images=image, return_tensors="pt")
|
| 98 |
+
pixel_values = inputs["pixel_values"].to(device=self.device, dtype=self.clip.dtype)
|
| 99 |
|
| 100 |
with torch.no_grad():
|
| 101 |
+
vision_outputs = self.clip(pixel_values=pixel_values)
|
| 102 |
+
# last_hidden_state: (batch, 257, 1024) β 1 CLS + 256 patches
|
| 103 |
+
patch_tokens = vision_outputs.last_hidden_state[:, 1:, :] # (1, 256, 1024)
|
| 104 |
|
| 105 |
+
# Project into LLM embedding space (trainable)
|
| 106 |
+
projected = self.projection(patch_tokens.float()) # (1, 256, 4096)
|
| 107 |
return projected
|
| 108 |
|
| 109 |
+
def encode_batch(self, images: list[Optional[Image.Image]]) -> list[Optional[torch.Tensor]]:
|
| 110 |
+
"""
|
| 111 |
+
Encode a batch of images. None entries pass through as None.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
images: List of PIL Images or Nones.
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
List of tensors (1, 256, 4096) or Nones matching input order.
|
| 118 |
+
"""
|
| 119 |
+
results: list[Optional[torch.Tensor]] = [None] * len(images)
|
| 120 |
+
valid_indices = [i for i, img in enumerate(images) if img is not None]
|
| 121 |
+
|
| 122 |
+
if not valid_indices:
|
| 123 |
+
return results
|
| 124 |
+
|
| 125 |
+
valid_images = [images[i] for i in valid_indices]
|
| 126 |
+
inputs = self.image_processor(images=valid_images, return_tensors="pt")
|
| 127 |
+
pixel_values = inputs["pixel_values"].to(device=self.device, dtype=self.clip.dtype)
|
| 128 |
|
| 129 |
+
with torch.no_grad():
|
| 130 |
+
vision_outputs = self.clip(pixel_values=pixel_values)
|
| 131 |
+
patch_tokens = vision_outputs.last_hidden_state[:, 1:, :] # (N, 256, 1024)
|
| 132 |
+
|
| 133 |
+
projected = self.projection(patch_tokens.float()) # (N, 256, 4096)
|
| 134 |
+
|
| 135 |
+
for batch_idx, orig_idx in enumerate(valid_indices):
|
| 136 |
+
results[orig_idx] = projected[batch_idx].unsqueeze(0) # (1, 256, 4096)
|
| 137 |
+
|
| 138 |
+
return results
|
| 139 |
+
|
| 140 |
+
def encode_screenshot(self, screenshot_path: Path) -> Optional[torch.Tensor]:
|
| 141 |
+
"""
|
| 142 |
+
Load a screenshot from disk and encode it.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
screenshot_path: Path to image file.
|
| 146 |
+
|
| 147 |
+
Returns:
|
| 148 |
+
Tensor of shape (1, 256, 4096).
|
| 149 |
+
"""
|
| 150 |
+
path = Path(screenshot_path)
|
| 151 |
+
if not path.exists():
|
| 152 |
+
raise FileNotFoundError(f"Screenshot not found: {path}")
|
| 153 |
+
image = Image.open(path).convert("RGB")
|
| 154 |
return self.encode_image(image)
|
| 155 |
|
| 156 |
def save_projection(self, save_dir: Optional[Path] = None) -> Path:
|
| 157 |
+
"""
|
| 158 |
+
Save only the trainable projection weights.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
save_dir: Directory to save to. Defaults to cache_dir/projection.
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
Path where weights were saved.
|
| 165 |
+
"""
|
| 166 |
+
save_path = Path(save_dir) if save_dir else self.cache_dir / "projection"
|
| 167 |
save_path.mkdir(parents=True, exist_ok=True)
|
| 168 |
torch.save(self.projection.state_dict(), save_path / "projection.pt")
|
| 169 |
+
print(f"[VisionEncoder] Projection saved to {save_path}")
|
| 170 |
return save_path
|
| 171 |
|
| 172 |
def load_projection(self, load_dir: Path) -> None:
|
| 173 |
+
"""
|
| 174 |
+
Load projection weights from disk.
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
load_dir: Directory containing projection.pt.
|
| 178 |
+
"""
|
| 179 |
+
weights_path = Path(load_dir) / "projection.pt"
|
| 180 |
if not weights_path.exists():
|
| 181 |
raise FileNotFoundError(f"Projection weights not found: {weights_path}")
|
| 182 |
state_dict = torch.load(weights_path, map_location=self.device, weights_only=True)
|
| 183 |
self.projection.load_state_dict(state_dict)
|
| 184 |
+
print(f"[VisionEncoder] Projection loaded from {load_dir}")
|
| 185 |
+
|
| 186 |
+
def get_num_visual_tokens(self) -> int:
|
| 187 |
+
"""Return the number of visual tokens produced per image (256)."""
|
| 188 |
+
return self.NUM_PATCHES
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# ββ Test block ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 192 |
+
if __name__ == "__main__":
|
| 193 |
+
print("=" * 60)
|
| 194 |
+
print(" MINDI 1.5 β Vision Encoder Test")
|
| 195 |
+
print("=" * 60)
|
| 196 |
+
print()
|
| 197 |
+
|
| 198 |
+
# 1. Initialize encoder
|
| 199 |
+
encoder = VisionEncoder(
|
| 200 |
+
model_name="openai/clip-vit-large-patch14",
|
| 201 |
+
llm_hidden_size=4096,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# 2. Create a dummy image (224Γ224 RGB)
|
| 205 |
+
dummy_image = Image.new("RGB", (224, 224), color=(128, 128, 128))
|
| 206 |
+
|
| 207 |
+
# 3. Encode single image
|
| 208 |
+
print("\n Encoding single image ...")
|
| 209 |
+
output = encoder.encode_image(dummy_image)
|
| 210 |
+
assert output is not None
|
| 211 |
+
print(f" Output shape: {output.shape}")
|
| 212 |
+
assert output.shape == (1, 256, 4096), f"Expected (1, 256, 4096), got {output.shape}"
|
| 213 |
+
|
| 214 |
+
# 4. Encode None β should return None
|
| 215 |
+
none_output = encoder.encode_image(None)
|
| 216 |
+
assert none_output is None, "Expected None for None input"
|
| 217 |
+
print(" None input β None output β")
|
| 218 |
+
|
| 219 |
+
# 5. Encode batch (mixed with None)
|
| 220 |
+
print("\n Encoding batch [image, None, image] ...")
|
| 221 |
+
batch_results = encoder.encode_batch([dummy_image, None, dummy_image])
|
| 222 |
+
assert batch_results[0] is not None and batch_results[0].shape == (1, 256, 4096)
|
| 223 |
+
assert batch_results[1] is None
|
| 224 |
+
assert batch_results[2] is not None and batch_results[2].shape == (1, 256, 4096)
|
| 225 |
+
print(f" Batch results: [{batch_results[0].shape}, None, {batch_results[2].shape}]")
|
| 226 |
+
|
| 227 |
+
# 6. Check trainable params (only projection should train)
|
| 228 |
+
trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
|
| 229 |
+
frozen = sum(p.numel() for p in encoder.parameters() if not p.requires_grad)
|
| 230 |
+
print(f"\n Trainable: {trainable:,}")
|
| 231 |
+
print(f" Frozen: {frozen:,}")
|
| 232 |
+
assert trainable == 1024 * 4096 + 4096, f"Unexpected trainable count: {trainable}"
|
| 233 |
+
assert frozen > trainable, "CLIP backbone should be frozen"
|
| 234 |
+
|
| 235 |
+
# 7. Save and reload projection
|
| 236 |
+
print("\n Testing save/load projection ...")
|
| 237 |
+
import tempfile
|
| 238 |
+
with tempfile.TemporaryDirectory() as tmp:
|
| 239 |
+
save_path = encoder.save_projection(Path(tmp))
|
| 240 |
+
old_weight = encoder.projection.weight.clone()
|
| 241 |
+
# Perturb weights
|
| 242 |
+
encoder.projection.weight.data.fill_(0.0)
|
| 243 |
+
assert not torch.equal(encoder.projection.weight, old_weight)
|
| 244 |
+
# Reload
|
| 245 |
+
encoder.load_projection(Path(tmp))
|
| 246 |
+
assert torch.equal(encoder.projection.weight, old_weight), "Weights not restored!"
|
| 247 |
+
print(" Save/load round-trip β")
|
| 248 |
+
|
| 249 |
+
print("\n β All vision encoder tests passed!")
|
| 250 |
+
print("=" * 60)
|
src/training/mindi_trainer.py
ADDED
|
@@ -0,0 +1,745 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MINDI 1.5 Vision-Coder β Trainer
|
| 3 |
+
|
| 4 |
+
Production-ready 3-phase training loop optimized for AMD MI300X (192GB VRAM).
|
| 5 |
+
Streams training data from disk (4.18GB train.jsonl) to avoid RAM exhaustion.
|
| 6 |
+
|
| 7 |
+
Phases:
|
| 8 |
+
Phase 1 (steps 0β5000): LoRA only, LR 2e-4, batch 16
|
| 9 |
+
Phase 2 (steps 5000β7500): Vision bridge only, LR 1e-5, batch 8
|
| 10 |
+
Phase 3 (steps 7500β10000): All trainable, LR 5e-5, batch 12
|
| 11 |
+
|
| 12 |
+
MI300X specifics:
|
| 13 |
+
- ROCm presents as CUDA to PyTorch (torch.cuda.* works)
|
| 14 |
+
- bf16 (NOT fp16) for AMD stability
|
| 15 |
+
- torch.compile() optional (works on ROCm)
|
| 16 |
+
- Gradient checkpointing enabled
|
| 17 |
+
- DataLoader: num_workers=4, pin_memory=True, prefetch_factor=2
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import json
|
| 23 |
+
import math
|
| 24 |
+
import time
|
| 25 |
+
from dataclasses import dataclass, field
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Any, Iterator, Optional
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
import torch.nn as nn
|
| 31 |
+
from torch.optim import AdamW
|
| 32 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
|
| 33 |
+
from torch.utils.data import DataLoader, IterableDataset
|
| 34 |
+
|
| 35 |
+
# ββ Configuration βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
+
|
| 37 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
|
| 41 |
+
class PhaseConfig:
|
| 42 |
+
"""Configuration for a single training phase."""
|
| 43 |
+
name: str
|
| 44 |
+
start_step: int
|
| 45 |
+
end_step: int
|
| 46 |
+
learning_rate: float
|
| 47 |
+
batch_size: int
|
| 48 |
+
gradient_accumulation_steps: int = 4
|
| 49 |
+
# Component toggles
|
| 50 |
+
lora: bool = False
|
| 51 |
+
vision_projection: bool = False
|
| 52 |
+
fusion: bool = False
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclass
|
| 56 |
+
class TrainingConfig:
|
| 57 |
+
"""Full training configuration."""
|
| 58 |
+
|
| 59 |
+
# Data paths
|
| 60 |
+
train_file: Path = field(default_factory=lambda: PROJECT_ROOT / "data" / "processed" / "train.jsonl")
|
| 61 |
+
val_file: Path = field(default_factory=lambda: PROJECT_ROOT / "data" / "processed" / "val.jsonl")
|
| 62 |
+
|
| 63 |
+
# Output
|
| 64 |
+
output_dir: Path = field(default_factory=lambda: PROJECT_ROOT / "checkpoints" / "training")
|
| 65 |
+
log_dir: Path = field(default_factory=lambda: PROJECT_ROOT / "logs" / "training")
|
| 66 |
+
|
| 67 |
+
# Model
|
| 68 |
+
max_seq_length: int = 8192
|
| 69 |
+
use_compile: bool = False
|
| 70 |
+
gradient_checkpointing: bool = True
|
| 71 |
+
|
| 72 |
+
# Hardware (MI300X defaults)
|
| 73 |
+
dtype: str = "bf16"
|
| 74 |
+
num_workers: int = 4
|
| 75 |
+
pin_memory: bool = True
|
| 76 |
+
prefetch_factor: int = 2
|
| 77 |
+
|
| 78 |
+
# Training
|
| 79 |
+
weight_decay: float = 0.01
|
| 80 |
+
warmup_ratio: float = 0.03
|
| 81 |
+
max_grad_norm: float = 1.0
|
| 82 |
+
seed: int = 42
|
| 83 |
+
|
| 84 |
+
# Logging
|
| 85 |
+
log_every_n_steps: int = 10
|
| 86 |
+
eval_every_n_steps: int = 250
|
| 87 |
+
save_every_n_steps: int = 500
|
| 88 |
+
|
| 89 |
+
# Phases
|
| 90 |
+
phases: list[PhaseConfig] = field(default_factory=lambda: [
|
| 91 |
+
PhaseConfig(
|
| 92 |
+
name="phase1_lora",
|
| 93 |
+
start_step=0, end_step=5000,
|
| 94 |
+
learning_rate=2e-4, batch_size=16,
|
| 95 |
+
lora=True, vision_projection=False, fusion=False,
|
| 96 |
+
),
|
| 97 |
+
PhaseConfig(
|
| 98 |
+
name="phase2_vision_bridge",
|
| 99 |
+
start_step=5000, end_step=7500,
|
| 100 |
+
learning_rate=1e-5, batch_size=8,
|
| 101 |
+
lora=False, vision_projection=True, fusion=True,
|
| 102 |
+
),
|
| 103 |
+
PhaseConfig(
|
| 104 |
+
name="phase3_all",
|
| 105 |
+
start_step=7500, end_step=10000,
|
| 106 |
+
learning_rate=5e-5, batch_size=12,
|
| 107 |
+
lora=True, vision_projection=True, fusion=True,
|
| 108 |
+
),
|
| 109 |
+
])
|
| 110 |
+
|
| 111 |
+
@property
|
| 112 |
+
def total_steps(self) -> int:
|
| 113 |
+
return self.phases[-1].end_step if self.phases else 0
|
| 114 |
+
|
| 115 |
+
@property
|
| 116 |
+
def torch_dtype(self) -> torch.dtype:
|
| 117 |
+
return {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.dtype]
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# ββ Streaming Dataset βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 121 |
+
|
| 122 |
+
class StreamingJSONLDataset(IterableDataset):
|
| 123 |
+
"""
|
| 124 |
+
Streams JSONL training data from disk line by line.
|
| 125 |
+
Tokenizes on-the-fly to avoid loading 4+ GB into RAM.
|
| 126 |
+
|
| 127 |
+
Expected JSONL format:
|
| 128 |
+
{"id": "...", "type": "...", "source": "...",
|
| 129 |
+
"messages": [{"role": "system", "content": "..."},
|
| 130 |
+
{"role": "user", "content": "..."},
|
| 131 |
+
{"role": "assistant", "content": "..."}],
|
| 132 |
+
"metadata": {...}}
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
def __init__(
|
| 136 |
+
self,
|
| 137 |
+
file_path: Path,
|
| 138 |
+
tokenizer: Any,
|
| 139 |
+
max_length: int = 8192,
|
| 140 |
+
shuffle_buffer: int = 10000,
|
| 141 |
+
seed: int = 42,
|
| 142 |
+
) -> None:
|
| 143 |
+
self.file_path = Path(file_path)
|
| 144 |
+
self.tokenizer = tokenizer
|
| 145 |
+
self.max_length = max_length
|
| 146 |
+
self.shuffle_buffer = shuffle_buffer
|
| 147 |
+
self.seed = seed
|
| 148 |
+
|
| 149 |
+
if not self.file_path.exists():
|
| 150 |
+
raise FileNotFoundError(f"Training data not found: {self.file_path}")
|
| 151 |
+
|
| 152 |
+
def _format_messages(self, messages: list[dict[str, str]]) -> str:
|
| 153 |
+
"""Format chat messages into a single training string."""
|
| 154 |
+
# Use the tokenizer's chat template if available
|
| 155 |
+
if hasattr(self.tokenizer, "apply_chat_template"):
|
| 156 |
+
return self.tokenizer.apply_chat_template(
|
| 157 |
+
messages, tokenize=False, add_generation_prompt=False
|
| 158 |
+
)
|
| 159 |
+
# Fallback: simple concatenation
|
| 160 |
+
parts = []
|
| 161 |
+
for msg in messages:
|
| 162 |
+
role = msg.get("role", "user")
|
| 163 |
+
content = msg.get("content", "")
|
| 164 |
+
parts.append(f"<|{role}|>\n{content}")
|
| 165 |
+
return "\n".join(parts)
|
| 166 |
+
|
| 167 |
+
def _tokenize(self, text: str) -> Optional[dict[str, torch.Tensor]]:
|
| 168 |
+
"""Tokenize text and create training labels."""
|
| 169 |
+
encoded = self.tokenizer(
|
| 170 |
+
text,
|
| 171 |
+
max_length=self.max_length,
|
| 172 |
+
truncation=True,
|
| 173 |
+
padding="max_length",
|
| 174 |
+
return_tensors="pt",
|
| 175 |
+
)
|
| 176 |
+
input_ids = encoded["input_ids"].squeeze(0)
|
| 177 |
+
attention_mask = encoded["attention_mask"].squeeze(0)
|
| 178 |
+
|
| 179 |
+
# Labels = input_ids, with padding tokens masked as -100
|
| 180 |
+
labels = input_ids.clone()
|
| 181 |
+
labels[attention_mask == 0] = -100
|
| 182 |
+
|
| 183 |
+
return {
|
| 184 |
+
"input_ids": input_ids,
|
| 185 |
+
"attention_mask": attention_mask,
|
| 186 |
+
"labels": labels,
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
def _line_iterator(self) -> Iterator[dict]:
|
| 190 |
+
"""Iterate over JSONL file line by line."""
|
| 191 |
+
with open(self.file_path, "r", encoding="utf-8") as f:
|
| 192 |
+
for line in f:
|
| 193 |
+
line = line.strip()
|
| 194 |
+
if line:
|
| 195 |
+
yield json.loads(line)
|
| 196 |
+
|
| 197 |
+
def _shuffled_iterator(self) -> Iterator[dict]:
|
| 198 |
+
"""Reservoir-style shuffle buffer for streaming data."""
|
| 199 |
+
import random
|
| 200 |
+
rng = random.Random(self.seed)
|
| 201 |
+
buffer: list[dict] = []
|
| 202 |
+
|
| 203 |
+
for item in self._line_iterator():
|
| 204 |
+
buffer.append(item)
|
| 205 |
+
if len(buffer) >= self.shuffle_buffer:
|
| 206 |
+
rng.shuffle(buffer)
|
| 207 |
+
yield from buffer
|
| 208 |
+
buffer.clear()
|
| 209 |
+
|
| 210 |
+
# Flush remaining items
|
| 211 |
+
if buffer:
|
| 212 |
+
rng.shuffle(buffer)
|
| 213 |
+
yield from buffer
|
| 214 |
+
|
| 215 |
+
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
|
| 216 |
+
for example in self._shuffled_iterator():
|
| 217 |
+
messages = example.get("messages", [])
|
| 218 |
+
if not messages:
|
| 219 |
+
continue
|
| 220 |
+
text = self._format_messages(messages)
|
| 221 |
+
tokenized = self._tokenize(text)
|
| 222 |
+
if tokenized is not None:
|
| 223 |
+
yield tokenized
|
| 224 |
+
|
| 225 |
+
def count_lines(self) -> int:
|
| 226 |
+
"""Count total lines (for progress estimation). Reads file once."""
|
| 227 |
+
count = 0
|
| 228 |
+
with open(self.file_path, "r", encoding="utf-8") as f:
|
| 229 |
+
for _ in f:
|
| 230 |
+
count += 1
|
| 231 |
+
return count
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# ββ Trainer βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 235 |
+
|
| 236 |
+
class MINDITrainer:
|
| 237 |
+
"""
|
| 238 |
+
3-phase trainer for MINDI 1.5 Vision-Coder.
|
| 239 |
+
|
| 240 |
+
Optimized for AMD MI300X 192GB:
|
| 241 |
+
- bf16 mixed precision
|
| 242 |
+
- Gradient checkpointing
|
| 243 |
+
- Streaming data from disk
|
| 244 |
+
- Optional torch.compile()
|
| 245 |
+
- Phase-based component freezing/unfreezing
|
| 246 |
+
"""
|
| 247 |
+
|
| 248 |
+
def __init__(
|
| 249 |
+
self,
|
| 250 |
+
model: nn.Module,
|
| 251 |
+
config: TrainingConfig,
|
| 252 |
+
) -> None:
|
| 253 |
+
"""
|
| 254 |
+
Initialize the trainer.
|
| 255 |
+
|
| 256 |
+
Args:
|
| 257 |
+
model: MINDI15 model instance (already initialized).
|
| 258 |
+
config: Training configuration.
|
| 259 |
+
"""
|
| 260 |
+
self.model = model
|
| 261 |
+
self.config = config
|
| 262 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 263 |
+
self.global_step = 0
|
| 264 |
+
self.best_val_loss = float("inf")
|
| 265 |
+
|
| 266 |
+
# Create output directories
|
| 267 |
+
self.config.output_dir.mkdir(parents=True, exist_ok=True)
|
| 268 |
+
self.config.log_dir.mkdir(parents=True, exist_ok=True)
|
| 269 |
+
|
| 270 |
+
# Gradient checkpointing
|
| 271 |
+
if config.gradient_checkpointing:
|
| 272 |
+
base_model = self.model.architecture.get_model()
|
| 273 |
+
if hasattr(base_model, "gradient_checkpointing_enable"):
|
| 274 |
+
base_model.gradient_checkpointing_enable()
|
| 275 |
+
print("[MINDITrainer] Gradient checkpointing enabled")
|
| 276 |
+
|
| 277 |
+
# Optional torch.compile (works on ROCm)
|
| 278 |
+
if config.use_compile:
|
| 279 |
+
print("[MINDITrainer] Compiling model with torch.compile() ...")
|
| 280 |
+
self.model.architecture.peft_model = torch.compile(
|
| 281 |
+
self.model.architecture.peft_model
|
| 282 |
+
)
|
| 283 |
+
print("[MINDITrainer] Compilation complete")
|
| 284 |
+
|
| 285 |
+
# Mixed precision scaler (bf16 doesn't need GradScaler, but keep structure)
|
| 286 |
+
self.use_amp = config.dtype in ("bf16", "fp16")
|
| 287 |
+
self.amp_dtype = config.torch_dtype
|
| 288 |
+
|
| 289 |
+
# Training log
|
| 290 |
+
self.log_file = config.log_dir / "training_log.jsonl"
|
| 291 |
+
self.metrics_history: list[dict] = []
|
| 292 |
+
|
| 293 |
+
print(f"[MINDITrainer] Device: {self.device}")
|
| 294 |
+
print(f"[MINDITrainer] Dtype: {config.dtype}")
|
| 295 |
+
print(f"[MINDITrainer] Total steps: {config.total_steps}")
|
| 296 |
+
print(f"[MINDITrainer] Phases: {len(config.phases)}")
|
| 297 |
+
|
| 298 |
+
def _build_optimizer(self, phase: PhaseConfig) -> AdamW:
|
| 299 |
+
"""Build optimizer for the current phase (only trainable params)."""
|
| 300 |
+
params = [p for p in self.model.parameters() if p.requires_grad]
|
| 301 |
+
if not params:
|
| 302 |
+
raise RuntimeError(f"No trainable parameters in phase '{phase.name}'")
|
| 303 |
+
return AdamW(
|
| 304 |
+
params,
|
| 305 |
+
lr=phase.learning_rate,
|
| 306 |
+
weight_decay=self.config.weight_decay,
|
| 307 |
+
betas=(0.9, 0.95),
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
def _build_scheduler(
|
| 311 |
+
self, optimizer: AdamW, phase: PhaseConfig
|
| 312 |
+
) -> torch.optim.lr_scheduler.LRScheduler:
|
| 313 |
+
"""Build LR scheduler: linear warmup + cosine decay."""
|
| 314 |
+
phase_steps = phase.end_step - phase.start_step
|
| 315 |
+
warmup_steps = max(1, int(phase_steps * self.config.warmup_ratio))
|
| 316 |
+
decay_steps = max(1, phase_steps - warmup_steps)
|
| 317 |
+
|
| 318 |
+
warmup = LinearLR(
|
| 319 |
+
optimizer,
|
| 320 |
+
start_factor=0.01,
|
| 321 |
+
end_factor=1.0,
|
| 322 |
+
total_iters=warmup_steps,
|
| 323 |
+
)
|
| 324 |
+
cosine = CosineAnnealingLR(
|
| 325 |
+
optimizer,
|
| 326 |
+
T_max=decay_steps,
|
| 327 |
+
eta_min=phase.learning_rate * 0.1,
|
| 328 |
+
)
|
| 329 |
+
return SequentialLR(
|
| 330 |
+
optimizer,
|
| 331 |
+
schedulers=[warmup, cosine],
|
| 332 |
+
milestones=[warmup_steps],
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
def _build_dataloader(
|
| 336 |
+
self, file_path: Path, batch_size: int, shuffle_buffer: int = 10000
|
| 337 |
+
) -> DataLoader:
|
| 338 |
+
"""Build a streaming DataLoader."""
|
| 339 |
+
dataset = StreamingJSONLDataset(
|
| 340 |
+
file_path=file_path,
|
| 341 |
+
tokenizer=self.model.tokenizer,
|
| 342 |
+
max_length=self.config.max_seq_length,
|
| 343 |
+
shuffle_buffer=shuffle_buffer,
|
| 344 |
+
seed=self.config.seed,
|
| 345 |
+
)
|
| 346 |
+
return DataLoader(
|
| 347 |
+
dataset,
|
| 348 |
+
batch_size=batch_size,
|
| 349 |
+
num_workers=self.config.num_workers,
|
| 350 |
+
pin_memory=self.config.pin_memory,
|
| 351 |
+
prefetch_factor=self.config.prefetch_factor if self.config.num_workers > 0 else None,
|
| 352 |
+
drop_last=True,
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
def _log_metrics(self, metrics: dict) -> None:
|
| 356 |
+
"""Append metrics to log file and history."""
|
| 357 |
+
self.metrics_history.append(metrics)
|
| 358 |
+
with open(self.log_file, "a", encoding="utf-8") as f:
|
| 359 |
+
f.write(json.dumps(metrics) + "\n")
|
| 360 |
+
|
| 361 |
+
@torch.no_grad()
|
| 362 |
+
def evaluate(self, val_loader: DataLoader, max_batches: int = 50) -> float:
|
| 363 |
+
"""
|
| 364 |
+
Run validation and return average loss.
|
| 365 |
+
|
| 366 |
+
Args:
|
| 367 |
+
val_loader: Validation DataLoader.
|
| 368 |
+
max_batches: Maximum batches to evaluate (for speed).
|
| 369 |
+
|
| 370 |
+
Returns:
|
| 371 |
+
Average validation loss.
|
| 372 |
+
"""
|
| 373 |
+
self.model.eval()
|
| 374 |
+
total_loss = 0.0
|
| 375 |
+
count = 0
|
| 376 |
+
|
| 377 |
+
for batch_idx, batch in enumerate(val_loader):
|
| 378 |
+
if batch_idx >= max_batches:
|
| 379 |
+
break
|
| 380 |
+
|
| 381 |
+
input_ids = batch["input_ids"].to(self.device)
|
| 382 |
+
attention_mask = batch["attention_mask"].to(self.device)
|
| 383 |
+
labels = batch["labels"].to(self.device)
|
| 384 |
+
|
| 385 |
+
with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
|
| 386 |
+
result = self.model(
|
| 387 |
+
input_ids=input_ids,
|
| 388 |
+
attention_mask=attention_mask,
|
| 389 |
+
labels=labels,
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
if result["loss"] is not None:
|
| 393 |
+
total_loss += result["loss"].item()
|
| 394 |
+
count += 1
|
| 395 |
+
|
| 396 |
+
self.model.train()
|
| 397 |
+
return total_loss / max(count, 1)
|
| 398 |
+
|
| 399 |
+
def _save_checkpoint(self, phase_name: str, step: int, val_loss: float) -> Path:
|
| 400 |
+
"""Save a training checkpoint."""
|
| 401 |
+
ckpt_dir = self.config.output_dir / f"{phase_name}_step{step}"
|
| 402 |
+
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
| 403 |
+
|
| 404 |
+
# Save model weights
|
| 405 |
+
self.model.save(ckpt_dir)
|
| 406 |
+
|
| 407 |
+
# Save trainer state
|
| 408 |
+
state = {
|
| 409 |
+
"global_step": self.global_step,
|
| 410 |
+
"phase": phase_name,
|
| 411 |
+
"step_in_phase": step,
|
| 412 |
+
"val_loss": val_loss,
|
| 413 |
+
"best_val_loss": self.best_val_loss,
|
| 414 |
+
}
|
| 415 |
+
torch.save(state, ckpt_dir / "trainer_state.pt")
|
| 416 |
+
|
| 417 |
+
print(f"[MINDITrainer] Checkpoint saved: {ckpt_dir}")
|
| 418 |
+
return ckpt_dir
|
| 419 |
+
|
| 420 |
+
def train_phase(self, phase: PhaseConfig) -> dict:
|
| 421 |
+
"""
|
| 422 |
+
Execute a single training phase.
|
| 423 |
+
|
| 424 |
+
Args:
|
| 425 |
+
phase: Phase configuration.
|
| 426 |
+
|
| 427 |
+
Returns:
|
| 428 |
+
Dict with phase training metrics.
|
| 429 |
+
"""
|
| 430 |
+
print()
|
| 431 |
+
print("=" * 60)
|
| 432 |
+
print(f" Phase: {phase.name}")
|
| 433 |
+
print(f" Steps: {phase.start_step} β {phase.end_step}")
|
| 434 |
+
print(f" LR: {phase.learning_rate} | Batch: {phase.batch_size}")
|
| 435 |
+
print(f" Components: LoRA={phase.lora}, Vision={phase.vision_projection}, "
|
| 436 |
+
f"Fusion={phase.fusion}")
|
| 437 |
+
print("=" * 60)
|
| 438 |
+
|
| 439 |
+
# Set trainable components
|
| 440 |
+
self.model.set_trainable_components(
|
| 441 |
+
lora=phase.lora,
|
| 442 |
+
vision_projection=phase.vision_projection,
|
| 443 |
+
fusion=phase.fusion,
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
# Build optimizer and scheduler for this phase
|
| 447 |
+
optimizer = self._build_optimizer(phase)
|
| 448 |
+
scheduler = self._build_scheduler(optimizer, phase)
|
| 449 |
+
|
| 450 |
+
# Build data loaders
|
| 451 |
+
train_loader = self._build_dataloader(
|
| 452 |
+
self.config.train_file, phase.batch_size
|
| 453 |
+
)
|
| 454 |
+
val_loader = self._build_dataloader(
|
| 455 |
+
self.config.val_file, batch_size=max(phase.batch_size // 2, 1),
|
| 456 |
+
shuffle_buffer=1000,
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
self.model.train()
|
| 460 |
+
phase_steps = phase.end_step - phase.start_step
|
| 461 |
+
step_in_phase = 0
|
| 462 |
+
accum_loss = 0.0
|
| 463 |
+
accum_count = 0
|
| 464 |
+
phase_start_time = time.time()
|
| 465 |
+
|
| 466 |
+
train_iter = iter(train_loader)
|
| 467 |
+
|
| 468 |
+
while step_in_phase < phase_steps:
|
| 469 |
+
# Get next batch (restart iterator if exhausted = new epoch)
|
| 470 |
+
try:
|
| 471 |
+
batch = next(train_iter)
|
| 472 |
+
except StopIteration:
|
| 473 |
+
train_iter = iter(train_loader)
|
| 474 |
+
batch = next(train_iter)
|
| 475 |
+
|
| 476 |
+
input_ids = batch["input_ids"].to(self.device)
|
| 477 |
+
attention_mask = batch["attention_mask"].to(self.device)
|
| 478 |
+
labels = batch["labels"].to(self.device)
|
| 479 |
+
|
| 480 |
+
# Forward pass with mixed precision
|
| 481 |
+
with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
|
| 482 |
+
result = self.model(
|
| 483 |
+
input_ids=input_ids,
|
| 484 |
+
attention_mask=attention_mask,
|
| 485 |
+
labels=labels,
|
| 486 |
+
)
|
| 487 |
+
loss = result["loss"]
|
| 488 |
+
|
| 489 |
+
if loss is None:
|
| 490 |
+
continue
|
| 491 |
+
|
| 492 |
+
# Scale loss for gradient accumulation
|
| 493 |
+
loss = loss / phase.gradient_accumulation_steps
|
| 494 |
+
|
| 495 |
+
# Backward pass
|
| 496 |
+
loss.backward()
|
| 497 |
+
accum_loss += loss.item() * phase.gradient_accumulation_steps
|
| 498 |
+
accum_count += 1
|
| 499 |
+
|
| 500 |
+
# Optimizer step (every gradient_accumulation_steps)
|
| 501 |
+
if accum_count % phase.gradient_accumulation_steps == 0:
|
| 502 |
+
# Gradient clipping
|
| 503 |
+
torch.nn.utils.clip_grad_norm_(
|
| 504 |
+
[p for p in self.model.parameters() if p.requires_grad],
|
| 505 |
+
self.config.max_grad_norm,
|
| 506 |
+
)
|
| 507 |
+
optimizer.step()
|
| 508 |
+
scheduler.step()
|
| 509 |
+
optimizer.zero_grad()
|
| 510 |
+
|
| 511 |
+
step_in_phase += 1
|
| 512 |
+
self.global_step += 1
|
| 513 |
+
avg_loss = accum_loss / phase.gradient_accumulation_steps
|
| 514 |
+
accum_loss = 0.0
|
| 515 |
+
|
| 516 |
+
# Logging
|
| 517 |
+
if step_in_phase % self.config.log_every_n_steps == 0:
|
| 518 |
+
elapsed = time.time() - phase_start_time
|
| 519 |
+
steps_per_sec = step_in_phase / elapsed if elapsed > 0 else 0.0
|
| 520 |
+
eta_sec = (phase_steps - step_in_phase) / steps_per_sec if steps_per_sec > 0 else 0.0
|
| 521 |
+
|
| 522 |
+
metrics = {
|
| 523 |
+
"phase": phase.name,
|
| 524 |
+
"global_step": self.global_step,
|
| 525 |
+
"step_in_phase": step_in_phase,
|
| 526 |
+
"loss": round(avg_loss, 4),
|
| 527 |
+
"lr": optimizer.param_groups[0]["lr"],
|
| 528 |
+
"steps_per_sec": round(steps_per_sec, 3),
|
| 529 |
+
"eta_minutes": round(eta_sec / 60, 1),
|
| 530 |
+
"elapsed_minutes": round(elapsed / 60, 1),
|
| 531 |
+
}
|
| 532 |
+
self._log_metrics(metrics)
|
| 533 |
+
print(f" [{phase.name}] step {step_in_phase}/{phase_steps} | "
|
| 534 |
+
f"loss={avg_loss:.4f} | "
|
| 535 |
+
f"lr={optimizer.param_groups[0]['lr']:.2e} | "
|
| 536 |
+
f"speed={steps_per_sec:.2f} steps/s | "
|
| 537 |
+
f"ETA={eta_sec / 60:.1f}min")
|
| 538 |
+
|
| 539 |
+
# Evaluation
|
| 540 |
+
if step_in_phase % self.config.eval_every_n_steps == 0:
|
| 541 |
+
val_loss = self.evaluate(val_loader)
|
| 542 |
+
print(f" [{phase.name}] EVAL step {step_in_phase} | val_loss={val_loss:.4f}")
|
| 543 |
+
self._log_metrics({
|
| 544 |
+
"phase": phase.name,
|
| 545 |
+
"global_step": self.global_step,
|
| 546 |
+
"val_loss": round(val_loss, 4),
|
| 547 |
+
"type": "eval",
|
| 548 |
+
})
|
| 549 |
+
|
| 550 |
+
# Save best model
|
| 551 |
+
if val_loss < self.best_val_loss:
|
| 552 |
+
self.best_val_loss = val_loss
|
| 553 |
+
self._save_checkpoint(phase.name, step_in_phase, val_loss)
|
| 554 |
+
print(f" [{phase.name}] New best val_loss: {val_loss:.4f}")
|
| 555 |
+
|
| 556 |
+
# Periodic save
|
| 557 |
+
if step_in_phase % self.config.save_every_n_steps == 0:
|
| 558 |
+
self._save_checkpoint(phase.name, step_in_phase, self.best_val_loss)
|
| 559 |
+
|
| 560 |
+
# End-of-phase save
|
| 561 |
+
phase_elapsed = time.time() - phase_start_time
|
| 562 |
+
self._save_checkpoint(phase.name, step_in_phase, self.best_val_loss)
|
| 563 |
+
|
| 564 |
+
phase_summary = {
|
| 565 |
+
"phase": phase.name,
|
| 566 |
+
"total_steps": step_in_phase,
|
| 567 |
+
"elapsed_minutes": round(phase_elapsed / 60, 1),
|
| 568 |
+
"best_val_loss": round(self.best_val_loss, 4),
|
| 569 |
+
"type": "phase_complete",
|
| 570 |
+
}
|
| 571 |
+
self._log_metrics(phase_summary)
|
| 572 |
+
print(f"\n [{phase.name}] Complete β {step_in_phase} steps in "
|
| 573 |
+
f"{phase_elapsed / 60:.1f} min")
|
| 574 |
+
|
| 575 |
+
return phase_summary
|
| 576 |
+
|
| 577 |
+
def train(self) -> dict:
|
| 578 |
+
"""
|
| 579 |
+
Run all 3 training phases sequentially.
|
| 580 |
+
|
| 581 |
+
Returns:
|
| 582 |
+
Dict with complete training summary.
|
| 583 |
+
"""
|
| 584 |
+
print()
|
| 585 |
+
print("=" * 60)
|
| 586 |
+
print(" MINDI 1.5 β Training Start")
|
| 587 |
+
print(f" Total phases: {len(self.config.phases)}")
|
| 588 |
+
print(f" Total steps: {self.config.total_steps}")
|
| 589 |
+
print(f" Device: {self.device}")
|
| 590 |
+
print(f" Dtype: {self.config.dtype}")
|
| 591 |
+
print(f" Output: {self.config.output_dir}")
|
| 592 |
+
print("=" * 60)
|
| 593 |
+
|
| 594 |
+
torch.manual_seed(self.config.seed)
|
| 595 |
+
if torch.cuda.is_available():
|
| 596 |
+
torch.cuda.manual_seed_all(self.config.seed)
|
| 597 |
+
|
| 598 |
+
training_start = time.time()
|
| 599 |
+
phase_summaries = []
|
| 600 |
+
|
| 601 |
+
for phase in self.config.phases:
|
| 602 |
+
summary = self.train_phase(phase)
|
| 603 |
+
phase_summaries.append(summary)
|
| 604 |
+
|
| 605 |
+
total_elapsed = time.time() - training_start
|
| 606 |
+
|
| 607 |
+
# Final save
|
| 608 |
+
final_dir = self.config.output_dir / "final"
|
| 609 |
+
final_dir.mkdir(parents=True, exist_ok=True)
|
| 610 |
+
self.model.save(final_dir)
|
| 611 |
+
|
| 612 |
+
training_summary = {
|
| 613 |
+
"total_steps": self.global_step,
|
| 614 |
+
"total_minutes": round(total_elapsed / 60, 1),
|
| 615 |
+
"best_val_loss": round(self.best_val_loss, 4),
|
| 616 |
+
"phases": phase_summaries,
|
| 617 |
+
"type": "training_complete",
|
| 618 |
+
}
|
| 619 |
+
self._log_metrics(training_summary)
|
| 620 |
+
|
| 621 |
+
print()
|
| 622 |
+
print("=" * 60)
|
| 623 |
+
print(" MINDI 1.5 β Training Complete")
|
| 624 |
+
print(f" Total steps: {self.global_step}")
|
| 625 |
+
print(f" Total time: {total_elapsed / 60:.1f} minutes")
|
| 626 |
+
print(f" Best val loss: {self.best_val_loss:.4f}")
|
| 627 |
+
print(f" Final saved to: {final_dir}")
|
| 628 |
+
print("=" * 60)
|
| 629 |
+
|
| 630 |
+
return training_summary
|
| 631 |
+
|
| 632 |
+
def resume_from_checkpoint(self, checkpoint_dir: Path) -> None:
|
| 633 |
+
"""
|
| 634 |
+
Resume training from a checkpoint.
|
| 635 |
+
|
| 636 |
+
Args:
|
| 637 |
+
checkpoint_dir: Directory containing saved checkpoint.
|
| 638 |
+
"""
|
| 639 |
+
checkpoint_dir = Path(checkpoint_dir)
|
| 640 |
+
state_file = checkpoint_dir / "trainer_state.pt"
|
| 641 |
+
|
| 642 |
+
if not state_file.exists():
|
| 643 |
+
raise FileNotFoundError(f"Trainer state not found: {state_file}")
|
| 644 |
+
|
| 645 |
+
# Load model weights
|
| 646 |
+
self.model.load(checkpoint_dir)
|
| 647 |
+
|
| 648 |
+
# Load trainer state
|
| 649 |
+
state = torch.load(state_file, map_location=self.device, weights_only=True)
|
| 650 |
+
self.global_step = state["global_step"]
|
| 651 |
+
self.best_val_loss = state["best_val_loss"]
|
| 652 |
+
|
| 653 |
+
print(f"[MINDITrainer] Resumed from step {self.global_step} "
|
| 654 |
+
f"(val_loss={self.best_val_loss:.4f})")
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
# ββ Test block ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 658 |
+
if __name__ == "__main__":
|
| 659 |
+
print("=" * 60)
|
| 660 |
+
print(" MINDI 1.5 β Trainer Test")
|
| 661 |
+
print("=" * 60)
|
| 662 |
+
print()
|
| 663 |
+
|
| 664 |
+
# ββ Test 1: Config defaults ββββββββββββββββββββββββββββββββββ
|
| 665 |
+
print(" Test 1: TrainingConfig defaults")
|
| 666 |
+
config = TrainingConfig()
|
| 667 |
+
assert config.total_steps == 10000
|
| 668 |
+
assert config.dtype == "bf16"
|
| 669 |
+
assert config.torch_dtype == torch.bfloat16
|
| 670 |
+
assert len(config.phases) == 3
|
| 671 |
+
assert config.gradient_checkpointing is True
|
| 672 |
+
assert config.num_workers == 4
|
| 673 |
+
assert config.pin_memory is True
|
| 674 |
+
assert config.prefetch_factor == 2
|
| 675 |
+
print(f" Total steps: {config.total_steps}")
|
| 676 |
+
print(f" Dtype: {config.dtype}")
|
| 677 |
+
print(f" Phases: {[p.name for p in config.phases]}")
|
| 678 |
+
print(" β Config defaults correct")
|
| 679 |
+
|
| 680 |
+
# ββ Test 2: Phase configs ββββββββββββββββββββββββββββββββββββ
|
| 681 |
+
print("\n Test 2: Phase configurations")
|
| 682 |
+
p1, p2, p3 = config.phases
|
| 683 |
+
assert p1.name == "phase1_lora"
|
| 684 |
+
assert p1.batch_size == 16
|
| 685 |
+
assert p1.learning_rate == 2e-4
|
| 686 |
+
assert p1.lora is True and p1.vision_projection is False and p1.fusion is False
|
| 687 |
+
|
| 688 |
+
assert p2.name == "phase2_vision_bridge"
|
| 689 |
+
assert p2.batch_size == 8
|
| 690 |
+
assert p2.learning_rate == 1e-5
|
| 691 |
+
assert p2.lora is False and p2.vision_projection is True and p2.fusion is True
|
| 692 |
+
|
| 693 |
+
assert p3.name == "phase3_all"
|
| 694 |
+
assert p3.batch_size == 12
|
| 695 |
+
assert p3.learning_rate == 5e-5
|
| 696 |
+
assert p3.lora is True and p3.vision_projection is True and p3.fusion is True
|
| 697 |
+
print(" Phase 1: LoRA only, batch=16, lr=2e-4 β")
|
| 698 |
+
print(" Phase 2: Vision bridge, batch=8, lr=1e-5 β")
|
| 699 |
+
print(" Phase 3: All, batch=12, lr=5e-5 β")
|
| 700 |
+
|
| 701 |
+
# ββ Test 3: Streaming dataset (if data exists) βββββββββββββββ
|
| 702 |
+
print("\n Test 3: StreamingJSONLDataset")
|
| 703 |
+
train_path = config.train_file
|
| 704 |
+
if train_path.exists():
|
| 705 |
+
from transformers import AutoTokenizer
|
| 706 |
+
tok = AutoTokenizer.from_pretrained(
|
| 707 |
+
str(PROJECT_ROOT / "data" / "tokenizer" / "mindi_tokenizer"),
|
| 708 |
+
trust_remote_code=True,
|
| 709 |
+
)
|
| 710 |
+
dataset = StreamingJSONLDataset(
|
| 711 |
+
file_path=train_path,
|
| 712 |
+
tokenizer=tok,
|
| 713 |
+
max_length=512, # small for test
|
| 714 |
+
shuffle_buffer=100,
|
| 715 |
+
)
|
| 716 |
+
count = 0
|
| 717 |
+
for item in dataset:
|
| 718 |
+
assert "input_ids" in item
|
| 719 |
+
assert "attention_mask" in item
|
| 720 |
+
assert "labels" in item
|
| 721 |
+
assert item["input_ids"].shape[0] == 512
|
| 722 |
+
count += 1
|
| 723 |
+
if count >= 5:
|
| 724 |
+
break
|
| 725 |
+
print(f" Streamed {count} examples, shape={item['input_ids'].shape} β")
|
| 726 |
+
else:
|
| 727 |
+
print(f" [SKIP] Train file not found: {train_path}")
|
| 728 |
+
|
| 729 |
+
# ββ Test 4: PhaseConfig step ranges ββββββββββββββββββββββββββ
|
| 730 |
+
print("\n Test 4: Phase step continuity")
|
| 731 |
+
for i in range(1, len(config.phases)):
|
| 732 |
+
prev = config.phases[i - 1]
|
| 733 |
+
curr = config.phases[i]
|
| 734 |
+
assert prev.end_step == curr.start_step, \
|
| 735 |
+
f"Gap between {prev.name} and {curr.name}"
|
| 736 |
+
print(" All phases are contiguous β")
|
| 737 |
+
|
| 738 |
+
# ββ Test 5: Gradient accumulation ββββββββββββββββββββββββββββ
|
| 739 |
+
print("\n Test 5: Gradient accumulation steps")
|
| 740 |
+
for phase in config.phases:
|
| 741 |
+
assert phase.gradient_accumulation_steps == 4
|
| 742 |
+
print(" All phases: grad_accum=4 β")
|
| 743 |
+
|
| 744 |
+
print("\n β All trainer tests passed!")
|
| 745 |
+
print("=" * 60)
|