Faaz commited on
Commit
2ff5c54
Β·
1 Parent(s): 59c6c97

Day 3 COMPLETE: Full model architecture

Browse files

Files added:
- src/model/architecture.py (Qwen2.5-Coder-7B + LoRA)
- src/model/vision_encoder.py (CLIP ViT-L/14, 256 tokens)
- src/model/fusion_layer.py (Linear+LayerNorm prepend)
- src/model/mindi_model.py (MINDI15 complete model)
- src/training/mindi_trainer.py (3-phase MI300X trainer)
- scripts/train.py (master training script)
- configs/training_config.yaml (MI300X config)
- setup_mi300x.sh (MI300X setup script)
- scripts/quality_filter.py, split_data.py, data_stats.py
- scripts/upload_everything_to_hf.py

Model: Qwen2.5-Coder-7B + LoRA + CLIP Vision
Ready for MI300X training!

configs/training_config.yaml CHANGED
@@ -1,57 +1,107 @@
1
  # ==========================================
2
  # MINDI 1.5 Vision-Coder β€” Training Configuration
 
3
  # ==========================================
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  training:
6
- # Hardware targets
7
- local_device: "cuda" # RTX 4060 8GB β€” for dev/testing only
8
- cloud_device: "cuda" # MI300X 192GB β€” for actual training
9
- precision: "bf16"
10
-
11
- # Hyperparameters
12
- epochs: 3
13
- batch_size: 4
14
- gradient_accumulation_steps: 8
15
- effective_batch_size: 32 # batch_size * grad_accum
16
- learning_rate: 2.0e-4
17
- weight_decay: 0.01
18
- warmup_ratio: 0.03
19
- lr_scheduler: "cosine"
 
 
 
 
 
 
 
 
 
 
 
 
20
  max_grad_norm: 1.0
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # Sequence settings
23
- max_seq_length: 8192
24
- packing: true # Pack short examples together
25
-
26
- # Checkpointing
27
- save_strategy: "steps"
28
- save_steps: 500
29
- save_total_limit: 5
30
- checkpoint_dir: "./checkpoints"
31
- resume_from_checkpoint: null
32
-
33
- # Logging
34
- logging_steps: 10
35
- log_dir: "./logs/training"
36
- report_to: "wandb"
37
-
38
- # Evaluation
39
- eval_strategy: "steps"
40
- eval_steps: 250
41
- eval_samples: 1000
42
-
43
- # Memory optimization (for RTX 4060 local testing)
44
- local_overrides:
45
- batch_size: 1
46
- gradient_accumulation_steps: 16
47
- max_seq_length: 2048
48
- gradient_checkpointing: true
49
- optim: "adamw_8bit"
50
-
51
- wandb:
52
- project: "mindi-1.5-vision-coder"
53
- entity: "mindigenous"
54
  tags:
55
  - "mindi-1.5"
56
  - "lora"
57
  - "vision-coder"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # ==========================================
2
  # MINDI 1.5 Vision-Coder β€” Training Configuration
3
+ # Optimized for AMD MI300X 192GB VRAM
4
  # ==========================================
5
 
6
+ # ── Model ──────────────────────────────────────────────────────
7
+ model:
8
+ name: "Qwen/Qwen2.5-Coder-7B-Instruct"
9
+ hidden_size: 4096
10
+ dtype: "bf16" # bf16 required for MI300X stability (NOT fp16)
11
+ use_compile: true # torch.compile() works on ROCm
12
+ gradient_checkpointing: true # Save VRAM even with 192GB
13
+
14
+ # ── LoRA ───────────────────────────────────────────────────────
15
+ lora:
16
+ r: 64
17
+ alpha: 128
18
+ dropout: 0.05
19
+ bias: "none"
20
+ task_type: "CAUSAL_LM"
21
+ target_modules:
22
+ - q_proj
23
+ - k_proj
24
+ - v_proj
25
+ - o_proj
26
+ - gate_proj
27
+ - up_proj
28
+ - down_proj
29
+
30
+ # ── Vision ─────────────────────────────────────────────────────
31
+ vision:
32
+ clip_model: "openai/clip-vit-large-patch14"
33
+ visual_tokens: 256 # 16Γ—16 patches from ViT-L/14
34
+ projection_size: 4096 # Must match model.hidden_size
35
+ freeze_clip: true # Freeze CLIP backbone
36
+
37
+ # ── Training Phases ────────────────────────────────────────────
38
  training:
39
+ # Phase 1: LoRA only β€” teach coding patterns
40
+ phase1:
41
+ steps: 5000
42
+ lr: 2.0e-4
43
+ batch_size: 16 # MI300X can handle large batches
44
+ warmup_steps: 100
45
+ data_filter: "code_only"
46
+
47
+ # Phase 2: Vision bridge only β€” align visual tokens
48
+ phase2:
49
+ steps: 2500
50
+ lr: 1.0e-5
51
+ batch_size: 8 # Smaller batch for vision bridge
52
+ warmup_steps: 50
53
+ data_filter: "websight_only"
54
+
55
+ # Phase 3: All trainable β€” joint fine-tuning
56
+ phase3:
57
+ steps: 2500
58
+ lr: 5.0e-5
59
+ batch_size: 12
60
+ warmup_steps: 50
61
+ data_filter: "all"
62
+
63
+ # Shared training settings
64
+ grad_accumulation: 4
65
  max_grad_norm: 1.0
66
+ eval_every: 250
67
+ save_every: 500
68
+
69
+ # ── Data ───────────────────────────────────────────────────────
70
+ data:
71
+ train_file: "data/processed/train.jsonl" # 4.18GB, 1,304,486 examples
72
+ val_file: "data/processed/val.jsonl" # 0.23GB, 72,471 examples
73
+ max_length: 4096
74
+ shuffle_buffer: 10000 # Streaming shuffle buffer size
75
+ num_workers: 4 # DataLoader workers
76
+ pin_memory: true
77
+ prefetch_factor: 2
78
 
79
+ # ── Logging ────────────────────────────────────────────────────
80
+ logging:
81
+ wandb_project: "mindi-1.5-vision-coder"
82
+ wandb_entity: "mindigenous"
83
+ log_every: 10 # Log metrics every N steps
84
+ log_dir: "logs/training"
85
+ sample_every: 500 # Generate sample outputs every N steps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  tags:
87
  - "mindi-1.5"
88
  - "lora"
89
  - "vision-coder"
90
+ - "mi300x"
91
+
92
+ # ── Output ─────────────────────────────────────────────────────
93
+ output:
94
+ checkpoint_dir: "checkpoints/training"
95
+ best_model: "checkpoints/best"
96
+ hf_repo: "Mindigenous/MINDI-1.5-Vision-Coder"
97
+ push_every_phase: true
98
+
99
+ # ── Local Dev Overrides (RTX 4060 8GB) ────────────────────────
100
+ # Apply these when testing locally with --dry_run
101
+ local_overrides:
102
+ batch_size: 1
103
+ gradient_accumulation_steps: 16
104
+ max_length: 2048
105
+ gradient_checkpointing: true
106
+ use_compile: false
107
+ num_workers: 0
data/processed/filter_report.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "input_count": 1481497,
3
+ "kept_count": 1449428,
4
+ "rejected_count": 32069,
5
+ "kept_pct": 97.84,
6
+ "avg_tokens": 593.1,
7
+ "avg_quality": 6.487,
8
+ "total_tokens": 859694776,
9
+ "elapsed_seconds": 1023.6,
10
+ "filter_settings": {
11
+ "min_tokens": 50,
12
+ "max_tokens": 4096,
13
+ "min_quality": 5.0
14
+ },
15
+ "rejection_breakdown": {
16
+ "too_many_tokens": 30637,
17
+ "boilerplate_content": 1373,
18
+ "duplicate_content": 59
19
+ },
20
+ "source_kept": {
21
+ "codealpaca": 59241,
22
+ "codefeedback": 149865,
23
+ "websight": 250987,
24
+ "synthetic_nextjs": 90000,
25
+ "search_examples": 15000,
26
+ "sandbox_examples": 9000,
27
+ "starcoderdata": 569350,
28
+ "evol_code": 155998,
29
+ "magicoder": 149987
30
+ },
31
+ "source_rejected": {
32
+ "codealpaca": 741,
33
+ "codefeedback": 132,
34
+ "starcoderdata": 30650,
35
+ "evol_code": 518,
36
+ "magicoder": 13,
37
+ "websight": 15
38
+ },
39
+ "type_distribution": {
40
+ "code_generation": 1183441,
41
+ "vision_code": 250987,
42
+ "search": 15000
43
+ },
44
+ "language_distribution": {
45
+ "unknown": 490305,
46
+ "typescript": 375859,
47
+ "javascript": 298497,
48
+ "python": 211842,
49
+ "html": 36371,
50
+ "java": 32458,
51
+ "rust": 3709,
52
+ "go": 387
53
+ }
54
+ }
data/processed/split_meta.json ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "total": 1449428,
3
+ "train_count": 1304486,
4
+ "val_count": 72471,
5
+ "test_count": 72471,
6
+ "train_pct": 90.0,
7
+ "val_pct": 5.0,
8
+ "test_pct": 5.0,
9
+ "seed": 42,
10
+ "source_breakdown": {
11
+ "codealpaca": {
12
+ "total": 59241,
13
+ "train": 53317,
14
+ "val": 2962,
15
+ "test": 2962
16
+ },
17
+ "codefeedback": {
18
+ "total": 149865,
19
+ "train": 134879,
20
+ "val": 7493,
21
+ "test": 7493
22
+ },
23
+ "evol_code": {
24
+ "total": 155998,
25
+ "train": 140398,
26
+ "val": 7800,
27
+ "test": 7800
28
+ },
29
+ "magicoder": {
30
+ "total": 149987,
31
+ "train": 134989,
32
+ "val": 7499,
33
+ "test": 7499
34
+ },
35
+ "sandbox_examples": {
36
+ "total": 9000,
37
+ "train": 8100,
38
+ "val": 450,
39
+ "test": 450
40
+ },
41
+ "search_examples": {
42
+ "total": 15000,
43
+ "train": 13500,
44
+ "val": 750,
45
+ "test": 750
46
+ },
47
+ "starcoderdata": {
48
+ "total": 569350,
49
+ "train": 512414,
50
+ "val": 28468,
51
+ "test": 28468
52
+ },
53
+ "synthetic_nextjs": {
54
+ "total": 90000,
55
+ "train": 81000,
56
+ "val": 4500,
57
+ "test": 4500
58
+ },
59
+ "websight": {
60
+ "total": 250987,
61
+ "train": 225889,
62
+ "val": 12549,
63
+ "test": 12549
64
+ }
65
+ }
66
+ }
scripts/data_stats.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ MINDI 1.5 Vision-Coder β€” Dataset Statistics Report
4
+
5
+ Generates comprehensive statistics for the final train/val/test splits:
6
+ - Total counts and sizes
7
+ - Token distribution (min, max, mean, median, p95, p99)
8
+ - Quality score distribution
9
+ - Source breakdown
10
+ - Type breakdown
11
+ - Language breakdown
12
+ - Special token usage
13
+
14
+ Usage:
15
+ python scripts/data_stats.py # Full report
16
+ python scripts/data_stats.py --split train # Stats for train only
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import argparse
22
+ import json
23
+ import statistics
24
+ import sys
25
+ import time
26
+ from collections import Counter
27
+ from pathlib import Path
28
+
29
+ # ── Paths ─────────────────────────────────────────────────────────────
30
+
31
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
32
+ PROCESSED_DIR = PROJECT_ROOT / "data" / "processed"
33
+
34
+ SPLIT_FILES = {
35
+ "train": PROCESSED_DIR / "train.jsonl",
36
+ "val": PROCESSED_DIR / "val.jsonl",
37
+ "test": PROCESSED_DIR / "test.jsonl",
38
+ }
39
+
40
+ REPORT_FILE = PROCESSED_DIR / "dataset_stats.json"
41
+
42
+ # ── Special tokens to check ──────────────────────────────────────────
43
+
44
+ SPECIAL_TOKENS = [
45
+ "<|think_start|>", "<|think_end|>",
46
+ "<|code_start|>", "<|code_end|>",
47
+ "<|critique_start|>", "<|critique_end|>",
48
+ "<|suggest_start|>", "<|suggest_end|>",
49
+ "<|file_start|>", "<|file_end|>",
50
+ "<|search_start|>", "<|search_end|>",
51
+ "<|sandbox_start|>", "<|sandbox_end|>",
52
+ "<|vision_start|>", "<|vision_end|>",
53
+ "<|error_start|>", "<|error_end|>",
54
+ "<|fix_start|>", "<|fix_end|>",
55
+ ]
56
+
57
+
58
+ def percentile(sorted_data: list[int | float], p: float) -> float:
59
+ """Calculate the p-th percentile from sorted data."""
60
+ if not sorted_data:
61
+ return 0.0
62
+ k = (len(sorted_data) - 1) * (p / 100.0)
63
+ f = int(k)
64
+ c = f + 1
65
+ if c >= len(sorted_data):
66
+ return float(sorted_data[f])
67
+ return sorted_data[f] + (k - f) * (sorted_data[c] - sorted_data[f])
68
+
69
+
70
+ def compute_stats(file_path: Path, split_name: str) -> dict:
71
+ """Compute statistics for a single split file."""
72
+
73
+ if not file_path.exists():
74
+ return {"error": f"File not found: {file_path}"}
75
+
76
+ tokens_list: list[int] = []
77
+ quality_list: list[float] = []
78
+ source_counts: Counter = Counter()
79
+ type_counts: Counter = Counter()
80
+ lang_counts: Counter = Counter()
81
+ framework_counts: Counter = Counter()
82
+ has_vision_count = 0
83
+ special_token_counts: Counter = Counter()
84
+ msg_count_dist: Counter = Counter() # number of messages per example
85
+ total_chars = 0
86
+
87
+ count = 0
88
+ with open(file_path, "r", encoding="utf-8") as f:
89
+ for line in f:
90
+ line = line.strip()
91
+ if not line:
92
+ continue
93
+ try:
94
+ ex = json.loads(line)
95
+ except json.JSONDecodeError:
96
+ continue
97
+
98
+ count += 1
99
+ meta = ex.get("metadata", {})
100
+
101
+ # Token count
102
+ tokens = meta.get("tokens", 0)
103
+ tokens_list.append(tokens)
104
+
105
+ # Quality score
106
+ quality = meta.get("quality_score", 0.0)
107
+ quality_list.append(quality)
108
+
109
+ # Source, type, language, framework
110
+ source_counts[ex.get("source", "unknown")] += 1
111
+ type_counts[ex.get("type", "unknown")] += 1
112
+ lang_counts[meta.get("language", "unknown")] += 1
113
+ framework_counts[meta.get("framework", "none")] += 1
114
+
115
+ # Vision
116
+ if meta.get("has_vision", False):
117
+ has_vision_count += 1
118
+
119
+ # Messages
120
+ messages = ex.get("messages", [])
121
+ msg_count_dist[len(messages)] += 1
122
+
123
+ # Special tokens in assistant content
124
+ for msg in messages:
125
+ if msg.get("role") == "assistant":
126
+ content = msg.get("content", "")
127
+ total_chars += len(content)
128
+ for tok in SPECIAL_TOKENS:
129
+ if tok in content:
130
+ special_token_counts[tok] += 1
131
+
132
+ # Sort for percentile computation
133
+ tokens_sorted = sorted(tokens_list)
134
+ quality_sorted = sorted(quality_list)
135
+
136
+ file_size_mb = file_path.stat().st_size / (1024 * 1024)
137
+
138
+ stats = {
139
+ "split": split_name,
140
+ "file": file_path.name,
141
+ "file_size_mb": round(file_size_mb, 1),
142
+ "count": count,
143
+ "total_tokens": sum(tokens_list),
144
+ "total_chars_assistant": total_chars,
145
+ "has_vision": has_vision_count,
146
+ "tokens": {
147
+ "min": min(tokens_sorted) if tokens_sorted else 0,
148
+ "max": max(tokens_sorted) if tokens_sorted else 0,
149
+ "mean": round(statistics.mean(tokens_list), 1) if tokens_list else 0,
150
+ "median": round(statistics.median(tokens_list), 1) if tokens_list else 0,
151
+ "stdev": round(statistics.stdev(tokens_list), 1) if len(tokens_list) > 1 else 0,
152
+ "p5": round(percentile(tokens_sorted, 5), 1),
153
+ "p25": round(percentile(tokens_sorted, 25), 1),
154
+ "p75": round(percentile(tokens_sorted, 75), 1),
155
+ "p95": round(percentile(tokens_sorted, 95), 1),
156
+ "p99": round(percentile(tokens_sorted, 99), 1),
157
+ },
158
+ "quality_score": {
159
+ "min": round(min(quality_sorted), 2) if quality_sorted else 0,
160
+ "max": round(max(quality_sorted), 2) if quality_sorted else 0,
161
+ "mean": round(statistics.mean(quality_list), 2) if quality_list else 0,
162
+ "median": round(statistics.median(quality_list), 2) if quality_list else 0,
163
+ },
164
+ "source_distribution": dict(source_counts.most_common()),
165
+ "type_distribution": dict(type_counts.most_common()),
166
+ "language_distribution": dict(lang_counts.most_common(30)),
167
+ "framework_distribution": dict(framework_counts.most_common(15)),
168
+ "messages_per_example": dict(sorted(msg_count_dist.items())),
169
+ "special_token_usage": dict(special_token_counts.most_common()),
170
+ }
171
+
172
+ return stats
173
+
174
+
175
+ def print_stats(stats: dict) -> None:
176
+ """Pretty-print statistics for a split."""
177
+ if "error" in stats:
178
+ print(f" ERROR: {stats['error']}")
179
+ return
180
+
181
+ print(f" Split: {stats['split']}")
182
+ print(f" File: {stats['file']} ({stats['file_size_mb']:.1f} MB)")
183
+ print(f" Count: {stats['count']:,}")
184
+ print(f" Total tokens: {stats['total_tokens']:,}")
185
+ print(f" Vision examples: {stats['has_vision']:,}")
186
+ print()
187
+
188
+ t = stats["tokens"]
189
+ print(f" Token distribution:")
190
+ print(f" Min: {t['min']:>8,} P5: {t['p5']:>8,.0f}")
191
+ print(f" P25: {t['p25']:>8,.0f} Median: {t['median']:>8,.0f}")
192
+ print(f" Mean: {t['mean']:>8,.0f} P75: {t['p75']:>8,.0f}")
193
+ print(f" P95: {t['p95']:>8,.0f} P99: {t['p99']:>8,.0f}")
194
+ print(f" Max: {t['max']:>8,} Stdev: {t['stdev']:>8,.0f}")
195
+ print()
196
+
197
+ q = stats["quality_score"]
198
+ print(f" Quality score: min={q['min']:.1f} mean={q['mean']:.1f} median={q['median']:.1f} max={q['max']:.1f}")
199
+ print()
200
+
201
+ print(f" Source distribution:")
202
+ for src, cnt in stats["source_distribution"].items():
203
+ pct = cnt / stats["count"] * 100
204
+ print(f" {src:<25s} {cnt:>10,} ({pct:5.1f}%)")
205
+ print()
206
+
207
+ print(f" Type distribution:")
208
+ for t_name, cnt in list(stats["type_distribution"].items())[:10]:
209
+ pct = cnt / stats["count"] * 100
210
+ print(f" {t_name:<25s} {cnt:>10,} ({pct:5.1f}%)")
211
+ print()
212
+
213
+ print(f" Language distribution (top 15):")
214
+ for lang, cnt in list(stats["language_distribution"].items())[:15]:
215
+ pct = cnt / stats["count"] * 100
216
+ print(f" {lang:<25s} {cnt:>10,} ({pct:5.1f}%)")
217
+ print()
218
+
219
+ if stats["special_token_usage"]:
220
+ print(f" Special token usage (examples containing token):")
221
+ for tok, cnt in stats["special_token_usage"].items():
222
+ pct = cnt / stats["count"] * 100
223
+ print(f" {tok:<25s} {cnt:>10,} ({pct:5.1f}%)")
224
+ print()
225
+
226
+
227
+ def run_stats(split: str | None = None) -> None:
228
+ """Generate and display statistics."""
229
+ start = time.time()
230
+
231
+ if split:
232
+ files = {split: SPLIT_FILES.get(split)}
233
+ if files[split] is None:
234
+ print(f"ERROR: Unknown split '{split}'. Choose from: {list(SPLIT_FILES.keys())}")
235
+ sys.exit(1)
236
+ else:
237
+ files = SPLIT_FILES
238
+
239
+ all_stats = {}
240
+
241
+ for name, path in files.items():
242
+ print("=" * 60)
243
+ print(f" Computing stats for: {name}")
244
+ print("=" * 60)
245
+ stats = compute_stats(path, name)
246
+ all_stats[name] = stats
247
+ print_stats(stats)
248
+
249
+ # Save JSON report
250
+ REPORT_FILE.parent.mkdir(parents=True, exist_ok=True)
251
+ with open(REPORT_FILE, "w", encoding="utf-8") as f:
252
+ json.dump(all_stats, f, indent=2)
253
+ print(f"Full report saved to: {REPORT_FILE.name}")
254
+
255
+ elapsed = time.time() - start
256
+ print(f"Stats generated in {elapsed:.1f}s")
257
+
258
+
259
+ # ── CLI ───────────────────────────────────────────────────────────────
260
+
261
+ def main():
262
+ parser = argparse.ArgumentParser(
263
+ description="MINDI Dataset Statistics β€” comprehensive split analysis",
264
+ )
265
+ parser.add_argument("--split", type=str, choices=["train", "val", "test"],
266
+ help="Compute stats for a single split only")
267
+
268
+ args = parser.parse_args()
269
+ run_stats(split=args.split)
270
+
271
+
272
+ if __name__ == "__main__":
273
+ main()
scripts/quality_filter.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ MINDI 1.5 Vision-Coder β€” Quality Filter Pipeline
4
+
5
+ Filters mindi_all.jsonl to remove low-quality examples:
6
+ 1. Token length filter β€” drop if <50 tokens or >4096 tokens
7
+ 2. Duplicate detection β€” SHA-256 hash of assistant content
8
+ 3. JSON structure check β€” valid schema with required fields
9
+ 4. Special token check β€” assistant must have code_start/code_end pair
10
+ 5. Quality score filter β€” keep only quality_score >= 5.0
11
+ 6. Content heuristics β€” drop empty/trivial/boilerplate responses
12
+
13
+ Usage:
14
+ python scripts/quality_filter.py # Full run
15
+ python scripts/quality_filter.py --dry-run # Preview only
16
+ python scripts/quality_filter.py --min-tokens 100 # Custom min tokens
17
+ python scripts/quality_filter.py --max-tokens 8192 # Custom max tokens
18
+ python scripts/quality_filter.py --min-quality 7.0 # Stricter quality
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import argparse
24
+ import hashlib
25
+ import json
26
+ import sys
27
+ import time
28
+ from collections import Counter, defaultdict
29
+ from pathlib import Path
30
+
31
+ # ── Paths ─────────────────────────────────────────────────────────────
32
+
33
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
34
+ INPUT_FILE = PROJECT_ROOT / "data" / "processed" / "mindi_all.jsonl"
35
+ OUTPUT_FILE = PROJECT_ROOT / "data" / "processed" / "mindi_filtered.jsonl"
36
+ REJECT_FILE = PROJECT_ROOT / "data" / "processed" / "mindi_rejected.jsonl"
37
+ REPORT_FILE = PROJECT_ROOT / "data" / "processed" / "filter_report.json"
38
+
39
+ # ── Required schema fields ────────────────────────────────────────────
40
+
41
+ REQUIRED_FIELDS = {"id", "type", "source", "messages", "metadata"}
42
+ REQUIRED_METADATA = {"language", "tokens"}
43
+ VALID_ROLES = {"system", "user", "assistant"}
44
+
45
+ # ── Protected sources (hand-crafted gold data β€” lighter filtering) ─────
46
+
47
+ PROTECTED_SOURCES = {"sandbox_examples", "search_examples", "synthetic_nextjs"}
48
+
49
+ # ── MINDI agentic token scoring bonuses ───────────────────────────────
50
+ # Examples with these tokens teach the model to be an *agent*.
51
+ # Each occurrence adds to the quality_score before the threshold.
52
+
53
+ MINDI_TOKEN_BONUSES = {
54
+ "<|think_start|>": 2.0,
55
+ "<|search_start|>": 3.0,
56
+ "<|error_start|>": 3.0,
57
+ "<|sandbox_start|>": 3.0,
58
+ "<|critique_start|>": 2.0,
59
+ "<|suggest_start|>": 1.0,
60
+ }
61
+
62
+ # ── Special token pairs that assistant messages should contain ─────────
63
+
64
+ CODE_TOKEN_PAIRS = [
65
+ ("<|code_start|>", "<|code_end|>"),
66
+ ]
67
+
68
+ # At least one of these pairs should be present in assistant content
69
+ OPTIONAL_TOKEN_PAIRS = [
70
+ ("<|think_start|>", "<|think_end|>"),
71
+ ("<|critique_start|>", "<|critique_end|>"),
72
+ ("<|suggest_start|>", "<|suggest_end|>"),
73
+ ("<|file_start|>", "<|file_end|>"),
74
+ ("<|search_start|>", "<|search_end|>"),
75
+ ("<|sandbox_start|>", "<|sandbox_end|>"),
76
+ ("<|error_start|>", "<|error_end|>"),
77
+ ("<|fix_start|>", "<|fix_end|>"),
78
+ ]
79
+
80
+ # ── Rejection reasons ─────────────────────────────────────────────────
81
+
82
+ class Reason:
83
+ INVALID_JSON = "invalid_json"
84
+ MISSING_FIELDS = "missing_fields"
85
+ MISSING_METADATA = "missing_metadata"
86
+ NO_MESSAGES = "no_messages"
87
+ BAD_ROLES = "bad_message_roles"
88
+ NO_ASSISTANT = "no_assistant_message"
89
+ EMPTY_ASSISTANT = "empty_assistant_content"
90
+ TOO_SHORT = "too_few_tokens"
91
+ TOO_LONG = "too_many_tokens"
92
+ DUPLICATE = "duplicate_content"
93
+ LOW_QUALITY = "low_quality_score"
94
+ NO_CODE_TOKENS = "missing_code_tokens"
95
+ BOILERPLATE = "boilerplate_content"
96
+ UNMATCHED_TOKENS = "unmatched_special_tokens"
97
+
98
+
99
+ # ── Filter functions ──────────────────────────────────────────────────
100
+
101
+ def validate_schema(example: dict) -> str | None:
102
+ """Check required fields and structure. Returns rejection reason or None."""
103
+ # Top-level fields
104
+ missing = REQUIRED_FIELDS - set(example.keys())
105
+ if missing:
106
+ return Reason.MISSING_FIELDS
107
+
108
+ # Metadata fields
109
+ meta = example.get("metadata", {})
110
+ if not isinstance(meta, dict):
111
+ return Reason.MISSING_METADATA
112
+ missing_meta = REQUIRED_METADATA - set(meta.keys())
113
+ if missing_meta:
114
+ return Reason.MISSING_METADATA
115
+
116
+ # Messages array
117
+ messages = example.get("messages", [])
118
+ if not isinstance(messages, list) or len(messages) == 0:
119
+ return Reason.NO_MESSAGES
120
+
121
+ # Role validation
122
+ for msg in messages:
123
+ if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
124
+ return Reason.BAD_ROLES
125
+ if msg["role"] not in VALID_ROLES:
126
+ return Reason.BAD_ROLES
127
+
128
+ return None
129
+
130
+
131
+ def get_assistant_content(example: dict) -> str:
132
+ """Extract concatenated assistant message content."""
133
+ parts = []
134
+ for msg in example.get("messages", []):
135
+ if msg.get("role") == "assistant":
136
+ parts.append(msg.get("content", ""))
137
+ return "\n".join(parts)
138
+
139
+
140
+ def check_assistant_exists(example: dict) -> str | None:
141
+ """Must have at least one assistant message with non-empty content."""
142
+ content = get_assistant_content(example)
143
+ if not content:
144
+ return Reason.NO_ASSISTANT
145
+ if len(content.strip()) < 10:
146
+ return Reason.EMPTY_ASSISTANT
147
+ return None
148
+
149
+
150
+ def check_token_length(example: dict, min_tokens: int, max_tokens: int) -> str | None:
151
+ """Filter by token count stored in metadata."""
152
+ tokens = example.get("metadata", {}).get("tokens", 0)
153
+ if tokens < min_tokens:
154
+ return Reason.TOO_SHORT
155
+ if tokens > max_tokens:
156
+ return Reason.TOO_LONG
157
+ return None
158
+
159
+
160
+ def compute_mindi_bonus(example: dict) -> float:
161
+ """Compute bonus score for MINDI agentic special tokens."""
162
+ content = get_assistant_content(example)
163
+ bonus = 0.0
164
+ for token, value in MINDI_TOKEN_BONUSES.items():
165
+ if token in content:
166
+ bonus += value
167
+ return bonus
168
+
169
+
170
+ def check_quality_score(example: dict, min_quality: float) -> str | None:
171
+ """Filter by quality_score + MINDI token bonus."""
172
+ score = example.get("metadata", {}).get("quality_score", 0.0)
173
+ score += compute_mindi_bonus(example)
174
+ if score < min_quality:
175
+ return Reason.LOW_QUALITY
176
+ return None
177
+
178
+
179
+ def check_code_tokens(example: dict) -> str | None:
180
+ """Assistant content must contain code_start/code_end pair."""
181
+ content = get_assistant_content(example)
182
+
183
+ for start_tok, end_tok in CODE_TOKEN_PAIRS:
184
+ if start_tok in content and end_tok in content:
185
+ # Check ordering: start before end
186
+ if content.index(start_tok) < content.rindex(end_tok):
187
+ return None # OK
188
+
189
+ return Reason.NO_CODE_TOKENS
190
+
191
+
192
+ def check_unmatched_tokens(example: dict) -> str | None:
193
+ """Ensure all special token pairs are properly matched (start count == end count)."""
194
+ content = get_assistant_content(example)
195
+ all_pairs = CODE_TOKEN_PAIRS + OPTIONAL_TOKEN_PAIRS
196
+
197
+ for start_tok, end_tok in all_pairs:
198
+ start_count = content.count(start_tok)
199
+ end_count = content.count(end_tok)
200
+ if start_count != end_count:
201
+ return Reason.UNMATCHED_TOKENS
202
+
203
+ return None
204
+
205
+
206
+ def check_boilerplate(example: dict) -> str | None:
207
+ """Detect boilerplate/placeholder assistant responses."""
208
+ content = get_assistant_content(example)
209
+ content_lower = content.lower().strip()
210
+
211
+ # Very short code blocks (just placeholder)
212
+ code_markers = ("<|code_start|>", "<|code_end|>")
213
+ if code_markers[0] in content and code_markers[1] in content:
214
+ start_idx = content.index(code_markers[0]) + len(code_markers[0])
215
+ end_idx = content.index(code_markers[1])
216
+ code_body = content[start_idx:end_idx].strip()
217
+ if len(code_body) < 5:
218
+ return Reason.BOILERPLATE
219
+
220
+ # Repetitive content (same char repeated)
221
+ stripped = content_lower.replace(" ", "").replace("\n", "")
222
+ if len(stripped) > 20:
223
+ unique_chars = len(set(stripped))
224
+ if unique_chars < 5:
225
+ return Reason.BOILERPLATE
226
+
227
+ return None
228
+
229
+
230
+ def content_hash(example: dict) -> str:
231
+ """SHA-256 hash of assistant content for deduplication."""
232
+ content = get_assistant_content(example)
233
+ return hashlib.sha256(content.encode("utf-8", errors="replace")).hexdigest()
234
+
235
+
236
+ # ── Main pipeline ─────────────────────────────────────────────────────
237
+
238
+ def run_filter(
239
+ dry_run: bool = False,
240
+ min_tokens: int = 50,
241
+ max_tokens: int = 4096,
242
+ min_quality: float = 5.0,
243
+ ) -> None:
244
+ """Run the full quality filter pipeline."""
245
+
246
+ if not INPUT_FILE.exists():
247
+ print(f"ERROR: Input file not found: {INPUT_FILE}")
248
+ sys.exit(1)
249
+
250
+ # Count input lines
251
+ print(f"Counting input examples from {INPUT_FILE.name} ...")
252
+ total_input = sum(1 for _ in open(INPUT_FILE, "r", encoding="utf-8"))
253
+ print(f" Total input: {total_input:,} examples")
254
+ print()
255
+
256
+ # Filter settings
257
+ print("Filter settings:")
258
+ print(f" Min tokens: {min_tokens}")
259
+ print(f" Max tokens: {max_tokens}")
260
+ print(f" Min quality: {min_quality}")
261
+ print(f" Dry run: {dry_run}")
262
+ print()
263
+
264
+ # Stats tracking
265
+ kept = 0
266
+ rejected = 0
267
+ reject_reasons: Counter = Counter()
268
+ source_kept: Counter = Counter()
269
+ source_rejected: Counter = Counter()
270
+ seen_hashes: set[str] = set()
271
+ token_sum = 0
272
+ quality_sum = 0.0
273
+
274
+ # Type distribution
275
+ type_counts: Counter = Counter()
276
+
277
+ # Language distribution
278
+ lang_counts: Counter = Counter()
279
+
280
+ start_time = time.time()
281
+
282
+ out_f = None
283
+ rej_f = None
284
+ if not dry_run:
285
+ OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)
286
+ out_f = open(OUTPUT_FILE, "w", encoding="utf-8")
287
+ rej_f = open(REJECT_FILE, "w", encoding="utf-8")
288
+
289
+ try:
290
+ with open(INPUT_FILE, "r", encoding="utf-8") as f:
291
+ for line_num, line in enumerate(f, 1):
292
+ line = line.strip()
293
+ if not line:
294
+ continue
295
+
296
+ # Parse JSON
297
+ try:
298
+ example = json.loads(line)
299
+ except json.JSONDecodeError:
300
+ reject_reasons[Reason.INVALID_JSON] += 1
301
+ rejected += 1
302
+ if rej_f:
303
+ rej_f.write(line + "\n")
304
+ continue
305
+
306
+ source = example.get("source", "unknown")
307
+ is_protected = source in PROTECTED_SOURCES
308
+
309
+ # Run filter chain (order matters: cheapest first)
310
+ # Protected sources: schema + assistant + token length + unmatched only
311
+ # Regular sources: full chain + dedup
312
+ if is_protected:
313
+ rejection = (
314
+ validate_schema(example)
315
+ or check_assistant_exists(example)
316
+ or check_token_length(example, min_tokens, max_tokens)
317
+ or check_unmatched_tokens(example)
318
+ )
319
+ else:
320
+ rejection = (
321
+ validate_schema(example)
322
+ or check_assistant_exists(example)
323
+ or check_token_length(example, min_tokens, max_tokens)
324
+ or check_quality_score(example, min_quality)
325
+ or check_code_tokens(example)
326
+ or check_unmatched_tokens(example)
327
+ or check_boilerplate(example)
328
+ )
329
+
330
+ if rejection is None and not is_protected:
331
+ # Dedup check (skip for protected sources)
332
+ h = content_hash(example)
333
+ if h in seen_hashes:
334
+ rejection = Reason.DUPLICATE
335
+
336
+ if rejection is not None:
337
+ reject_reasons[rejection] += 1
338
+ source_rejected[source] += 1
339
+ rejected += 1
340
+ if rej_f:
341
+ rej_f.write(line + "\n")
342
+ continue
343
+
344
+ # Passed all filters
345
+ if not is_protected:
346
+ seen_hashes.add(h)
347
+ kept += 1
348
+ source_kept[source] += 1
349
+ token_sum += example.get("metadata", {}).get("tokens", 0)
350
+ quality_sum += example.get("metadata", {}).get("quality_score", 0.0)
351
+ type_counts[example.get("type", "unknown")] += 1
352
+ lang_counts[example.get("metadata", {}).get("language", "unknown")] += 1
353
+
354
+ if out_f:
355
+ out_f.write(line + "\n")
356
+
357
+ # Progress
358
+ if line_num % 50000 == 0:
359
+ elapsed = time.time() - start_time
360
+ rate = line_num / elapsed if elapsed > 0 else 0
361
+ pct = (line_num / total_input) * 100
362
+ print(f" [{pct:5.1f}%] Processed {line_num:>10,} | Kept {kept:>10,} | Rejected {rejected:>10,} | {rate:,.0f} ex/s")
363
+
364
+ finally:
365
+ if out_f:
366
+ out_f.close()
367
+ if rej_f:
368
+ rej_f.close()
369
+
370
+ elapsed = time.time() - start_time
371
+
372
+ # ── Summary report ────────────────────────────────────────────
373
+ print()
374
+ print("=" * 60)
375
+ print(" QUALITY FILTER REPORT")
376
+ print("=" * 60)
377
+ print(f" Input: {total_input:>10,} examples")
378
+ print(f" Kept: {kept:>10,} examples ({kept/total_input*100:.1f}%)")
379
+ print(f" Rejected: {rejected:>10,} examples ({rejected/total_input*100:.1f}%)")
380
+ print(f" Time: {elapsed:>10.1f} seconds")
381
+ print(f" Rate: {total_input/elapsed:>10,.0f} examples/sec")
382
+ print()
383
+
384
+ if kept > 0:
385
+ print(f" Avg tokens: {token_sum/kept:>10.0f}")
386
+ print(f" Avg quality: {quality_sum/kept:>10.2f}")
387
+ print(f" Total tokens:{token_sum:>10,}")
388
+ print()
389
+
390
+ # Rejection breakdown
391
+ print(" Rejection breakdown:")
392
+ for reason, count in reject_reasons.most_common():
393
+ pct = count / total_input * 100
394
+ print(f" {reason:<30s} {count:>10,} ({pct:.1f}%)")
395
+ print()
396
+
397
+ # Source breakdown
398
+ print(" Source breakdown (kept / total):")
399
+ all_sources = sorted(set(list(source_kept.keys()) + list(source_rejected.keys())))
400
+ for src in all_sources:
401
+ k = source_kept.get(src, 0)
402
+ total = k + source_rejected.get(src, 0)
403
+ pct = k / total * 100 if total > 0 else 0
404
+ print(f" {src:<25s} {k:>8,} / {total:>8,} ({pct:.1f}%)")
405
+ print()
406
+
407
+ # Type distribution
408
+ print(" Type distribution (kept):")
409
+ for t, c in type_counts.most_common(10):
410
+ print(f" {t:<25s} {c:>8,}")
411
+ print()
412
+
413
+ # Language distribution (top 15)
414
+ print(" Language distribution (kept, top 15):")
415
+ for lang, c in lang_counts.most_common(15):
416
+ print(f" {lang:<25s} {c:>8,}")
417
+ print()
418
+
419
+ if not dry_run:
420
+ print(f" Output: {OUTPUT_FILE}")
421
+ print(f" Rejects: {REJECT_FILE}")
422
+
423
+ # Save machine-readable report
424
+ report = {
425
+ "input_count": total_input,
426
+ "kept_count": kept,
427
+ "rejected_count": rejected,
428
+ "kept_pct": round(kept / total_input * 100, 2),
429
+ "avg_tokens": round(token_sum / kept, 1) if kept > 0 else 0,
430
+ "avg_quality": round(quality_sum / kept, 3) if kept > 0 else 0,
431
+ "total_tokens": token_sum,
432
+ "elapsed_seconds": round(elapsed, 1),
433
+ "filter_settings": {
434
+ "min_tokens": min_tokens,
435
+ "max_tokens": max_tokens,
436
+ "min_quality": min_quality,
437
+ },
438
+ "rejection_breakdown": dict(reject_reasons.most_common()),
439
+ "source_kept": dict(source_kept),
440
+ "source_rejected": dict(source_rejected),
441
+ "type_distribution": dict(type_counts.most_common()),
442
+ "language_distribution": dict(lang_counts.most_common(30)),
443
+ }
444
+ with open(REPORT_FILE, "w", encoding="utf-8") as rf:
445
+ json.dump(report, rf, indent=2)
446
+ print(f" Report: {REPORT_FILE}")
447
+
448
+ print("=" * 60)
449
+
450
+
451
+ # ── CLI ───────────────────────────────────────────────────────────────
452
+
453
+ def main():
454
+ parser = argparse.ArgumentParser(
455
+ description="MINDI Quality Filter β€” remove low-quality training examples",
456
+ )
457
+ parser.add_argument("--dry-run", action="store_true", help="Preview counts without writing output")
458
+ parser.add_argument("--min-tokens", type=int, default=50, help="Minimum token count (default: 50)")
459
+ parser.add_argument("--max-tokens", type=int, default=4096, help="Maximum token count (default: 4096)")
460
+ parser.add_argument("--min-quality", type=float, default=5.0, help="Minimum quality_score (default: 5.0)")
461
+
462
+ args = parser.parse_args()
463
+ run_filter(
464
+ dry_run=args.dry_run,
465
+ min_tokens=args.min_tokens,
466
+ max_tokens=args.max_tokens,
467
+ min_quality=args.min_quality,
468
+ )
469
+
470
+
471
+ if __name__ == "__main__":
472
+ main()
scripts/split_data.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ MINDI 1.5 Vision-Coder β€” Train / Validation / Test Split
4
+
5
+ Splits mindi_filtered.jsonl into:
6
+ - train.jsonl (90%)
7
+ - val.jsonl (5%)
8
+ - test.jsonl (5%)
9
+
10
+ Stratified by source to ensure proportional representation.
11
+ Deterministic with a fixed random seed.
12
+
13
+ Usage:
14
+ python scripts/split_data.py # Default 90/5/5
15
+ python scripts/split_data.py --train 0.85 --val 0.10 --test 0.05
16
+ python scripts/split_data.py --seed 42
17
+ python scripts/split_data.py --dry-run
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import argparse
23
+ import json
24
+ import random
25
+ import sys
26
+ import time
27
+ from collections import Counter
28
+ from pathlib import Path
29
+
30
+ # ── Paths ─────────────────────────────────────────────────────────────
31
+
32
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
33
+ INPUT_FILE = PROJECT_ROOT / "data" / "processed" / "mindi_filtered.jsonl"
34
+ OUTPUT_DIR = PROJECT_ROOT / "data" / "processed"
35
+ TRAIN_FILE = OUTPUT_DIR / "train.jsonl"
36
+ VAL_FILE = OUTPUT_DIR / "val.jsonl"
37
+ TEST_FILE = OUTPUT_DIR / "test.jsonl"
38
+
39
+
40
+ def run_split(
41
+ train_ratio: float = 0.90,
42
+ val_ratio: float = 0.05,
43
+ test_ratio: float = 0.05,
44
+ seed: int = 42,
45
+ dry_run: bool = False,
46
+ ) -> None:
47
+ """Split filtered data into train/val/test with stratification by source."""
48
+
49
+ # Validate ratios
50
+ total_ratio = train_ratio + val_ratio + test_ratio
51
+ if abs(total_ratio - 1.0) > 0.001:
52
+ print(f"ERROR: Ratios must sum to 1.0, got {total_ratio:.3f}")
53
+ sys.exit(1)
54
+
55
+ if not INPUT_FILE.exists():
56
+ print(f"ERROR: Input file not found: {INPUT_FILE}")
57
+ print(" Run quality_filter.py first to generate mindi_filtered.jsonl")
58
+ sys.exit(1)
59
+
60
+ print(f"Loading examples from {INPUT_FILE.name} ...")
61
+ start = time.time()
62
+
63
+ # Group lines by source for stratified splitting
64
+ source_lines: dict[str, list[str]] = {}
65
+ total = 0
66
+ with open(INPUT_FILE, "r", encoding="utf-8") as f:
67
+ for line in f:
68
+ line = line.strip()
69
+ if not line:
70
+ continue
71
+ total += 1
72
+ try:
73
+ example = json.loads(line)
74
+ source = example.get("source", "unknown")
75
+ except json.JSONDecodeError:
76
+ source = "unknown"
77
+ source_lines.setdefault(source, []).append(line)
78
+
79
+ load_time = time.time() - start
80
+ print(f" Loaded {total:,} examples in {load_time:.1f}s")
81
+ print(f" Sources: {len(source_lines)}")
82
+ print()
83
+
84
+ # Split settings
85
+ print(f"Split ratios: train={train_ratio:.0%} / val={val_ratio:.0%} / test={test_ratio:.0%}")
86
+ print(f"Random seed: {seed}")
87
+ print(f"Dry run: {dry_run}")
88
+ print()
89
+
90
+ rng = random.Random(seed)
91
+
92
+ train_lines: list[str] = []
93
+ val_lines: list[str] = []
94
+ test_lines: list[str] = []
95
+
96
+ source_stats: dict[str, dict[str, int]] = {}
97
+
98
+ for source in sorted(source_lines.keys()):
99
+ lines = source_lines[source]
100
+ rng.shuffle(lines)
101
+
102
+ n = len(lines)
103
+ n_val = max(1, round(n * val_ratio)) if n >= 3 else 0
104
+ n_test = max(1, round(n * test_ratio)) if n >= 3 else 0
105
+ n_train = n - n_val - n_test
106
+
107
+ # Edge case: if too few examples, put all in train
108
+ if n < 3:
109
+ n_train = n
110
+ n_val = 0
111
+ n_test = 0
112
+
113
+ train_lines.extend(lines[:n_train])
114
+ val_lines.extend(lines[n_train:n_train + n_val])
115
+ test_lines.extend(lines[n_train + n_val:])
116
+
117
+ source_stats[source] = {
118
+ "total": n,
119
+ "train": n_train,
120
+ "val": n_val,
121
+ "test": n_test,
122
+ }
123
+
124
+ # Shuffle final lists (so sources are interleaved)
125
+ rng.shuffle(train_lines)
126
+ rng.shuffle(val_lines)
127
+ rng.shuffle(test_lines)
128
+
129
+ # ── Summary ───────────────────────────────────────────────────
130
+ print("=" * 60)
131
+ print(" SPLIT SUMMARY")
132
+ print("=" * 60)
133
+ print(f" Total: {total:>10,}")
134
+ print(f" Train: {len(train_lines):>10,} ({len(train_lines)/total*100:.1f}%)")
135
+ print(f" Validation: {len(val_lines):>10,} ({len(val_lines)/total*100:.1f}%)")
136
+ print(f" Test: {len(test_lines):>10,} ({len(test_lines)/total*100:.1f}%)")
137
+ print()
138
+
139
+ print(" Per-source breakdown:")
140
+ print(f" {'Source':<25s} {'Total':>8s} {'Train':>8s} {'Val':>8s} {'Test':>8s}")
141
+ print(f" {'-'*25} {'-'*8} {'-'*8} {'-'*8} {'-'*8}")
142
+ for source in sorted(source_stats.keys()):
143
+ s = source_stats[source]
144
+ print(f" {source:<25s} {s['total']:>8,} {s['train']:>8,} {s['val']:>8,} {s['test']:>8,}")
145
+ print()
146
+
147
+ if not dry_run:
148
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
149
+
150
+ print("Writing files ...")
151
+ for path, lines, name in [
152
+ (TRAIN_FILE, train_lines, "train"),
153
+ (VAL_FILE, val_lines, "val"),
154
+ (TEST_FILE, test_lines, "test"),
155
+ ]:
156
+ with open(path, "w", encoding="utf-8") as f:
157
+ for line in lines:
158
+ f.write(line + "\n")
159
+ size_mb = path.stat().st_size / (1024 * 1024)
160
+ print(f" {name:<12s} β†’ {path.name:<20s} ({len(lines):>10,} examples, {size_mb:>8.1f} MB)")
161
+
162
+ # Save split metadata
163
+ meta = {
164
+ "total": total,
165
+ "train_count": len(train_lines),
166
+ "val_count": len(val_lines),
167
+ "test_count": len(test_lines),
168
+ "train_pct": round(len(train_lines) / total * 100, 2),
169
+ "val_pct": round(len(val_lines) / total * 100, 2),
170
+ "test_pct": round(len(test_lines) / total * 100, 2),
171
+ "seed": seed,
172
+ "source_breakdown": source_stats,
173
+ }
174
+ meta_path = OUTPUT_DIR / "split_meta.json"
175
+ with open(meta_path, "w", encoding="utf-8") as f:
176
+ json.dump(meta, f, indent=2)
177
+ print(f" Metadata β†’ {meta_path.name}")
178
+
179
+ elapsed = time.time() - start
180
+ print(f"\n Done in {elapsed:.1f}s")
181
+ print("=" * 60)
182
+
183
+
184
+ # ── CLI ───────────────────────────────────────────────────────────────
185
+
186
+ def main():
187
+ parser = argparse.ArgumentParser(
188
+ description="MINDI Data Splitter β€” stratified train/val/test split",
189
+ )
190
+ parser.add_argument("--train", type=float, default=0.90, help="Train ratio (default: 0.90)")
191
+ parser.add_argument("--val", type=float, default=0.05, help="Validation ratio (default: 0.05)")
192
+ parser.add_argument("--test", type=float, default=0.05, help="Test ratio (default: 0.05)")
193
+ parser.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
194
+ parser.add_argument("--dry-run", action="store_true", help="Preview split without writing files")
195
+
196
+ args = parser.parse_args()
197
+ run_split(
198
+ train_ratio=args.train,
199
+ val_ratio=args.val,
200
+ test_ratio=args.test,
201
+ seed=args.seed,
202
+ dry_run=args.dry_run,
203
+ )
204
+
205
+
206
+ if __name__ == "__main__":
207
+ main()
scripts/train.py CHANGED
@@ -1,39 +1,348 @@
 
1
  """
2
- MINDI 1.5 Vision-Coder β€” Training Launch Script
3
 
4
- Entry point for starting LoRA fine-tuning.
5
- Loads config, initializes model + dataset, and runs training.
 
 
 
 
 
 
6
  """
7
 
8
  from __future__ import annotations
9
 
10
  import argparse
 
 
 
11
  from pathlib import Path
12
 
 
 
 
13
 
14
- def main() -> None:
15
- """Parse args and launch training."""
16
- parser = argparse.ArgumentParser(description="MINDI 1.5 β€” Launch LoRA Training")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  parser.add_argument(
18
- "--config", type=str, default="./configs/training_config.yaml",
 
19
  help="Path to training config YAML",
20
  )
21
  parser.add_argument(
22
- "--local", action="store_true", default=True,
23
- help="Use local GPU overrides (RTX 4060 mode)",
24
  )
25
  parser.add_argument(
26
- "--cloud", action="store_true",
27
- help="Use cloud GPU settings (MI300X mode)",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  )
29
- args = parser.parse_args()
30
 
31
- local_mode = not args.cloud
32
- config_path = Path(args.config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- print(f"[MINDI Training] Config: {config_path}")
35
- print(f"[MINDI Training] Mode: {'local (RTX 4060)' if local_mode else 'cloud (MI300X)'}")
36
- print("[MINDI Training] Pipeline will be wired after Phase 3 setup.")
 
 
 
 
 
 
 
 
37
 
38
 
39
  if __name__ == "__main__":
 
1
+ #!/usr/bin/env python3
2
  """
3
+ MINDI 1.5 Vision-Coder β€” Master Training Script
4
 
5
+ Usage:
6
+ python scripts/train.py --phase 1 # Run phase 1 only
7
+ python scripts/train.py --phase all # Run all 3 phases
8
+ python scripts/train.py --phase 2 --resume checkpoints/training/phase1_lora_step5000
9
+ python scripts/train.py --dry_run # Test 10 steps only
10
+ python scripts/train.py --push_to_hub # Upload after training
11
+
12
+ Handles Ctrl+C gracefully: saves checkpoint before exit.
13
  """
14
 
15
  from __future__ import annotations
16
 
17
  import argparse
18
+ import signal
19
+ import sys
20
+ import traceback
21
  from pathlib import Path
22
 
23
+ # Resolve project root (scripts/ is one level deep)
24
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
25
+ sys.path.insert(0, str(PROJECT_ROOT))
26
 
27
+ import torch
28
+ import yaml
29
+
30
+
31
+ def parse_args() -> argparse.Namespace:
32
+ parser = argparse.ArgumentParser(
33
+ description="MINDI 1.5 Vision-Coder β€” Training",
34
+ formatter_class=argparse.RawDescriptionHelpFormatter,
35
+ )
36
+ parser.add_argument(
37
+ "--phase", type=str, default="all",
38
+ choices=["1", "2", "3", "all"],
39
+ help="Which phase(s) to run: 1, 2, 3, or all (default: all)",
40
+ )
41
+ parser.add_argument(
42
+ "--resume", type=str, default=None,
43
+ help="Path to checkpoint directory to resume from",
44
+ )
45
  parser.add_argument(
46
+ "--config", type=str,
47
+ default=str(PROJECT_ROOT / "configs" / "training_config.yaml"),
48
  help="Path to training config YAML",
49
  )
50
  parser.add_argument(
51
+ "--dry_run", action="store_true",
52
+ help="Test run: only 10 steps per phase",
53
  )
54
  parser.add_argument(
55
+ "--push_to_hub", action="store_true",
56
+ help="Push checkpoints to HuggingFace after each phase",
57
+ )
58
+ parser.add_argument(
59
+ "--no_wandb", action="store_true",
60
+ help="Disable WandB logging",
61
+ )
62
+ return parser.parse_args()
63
+
64
+
65
+ def load_config(config_path: str) -> dict:
66
+ """Load and return the training config YAML."""
67
+ path = Path(config_path)
68
+ if not path.exists():
69
+ raise FileNotFoundError(f"Config not found: {path}")
70
+ with open(path, "r", encoding="utf-8") as f:
71
+ return yaml.safe_load(f)
72
+
73
+
74
+ def build_training_config(raw: dict, dry_run: bool = False):
75
+ """Build TrainingConfig from parsed YAML."""
76
+ from src.training.mindi_trainer import PhaseConfig, TrainingConfig
77
+
78
+ training = raw.get("training", {})
79
+ data = raw.get("data", {})
80
+ output = raw.get("output", {})
81
+ logging_cfg = raw.get("logging", {})
82
+ model_cfg = raw.get("model", {})
83
+
84
+ # Build phase configs from YAML
85
+ phases = []
86
+ phase_defs = [
87
+ ("phase1", "phase1_lora", True, False, False),
88
+ ("phase2", "phase2_vision_bridge", False, True, True),
89
+ ("phase3", "phase3_all", True, True, True),
90
+ ]
91
+ cumulative_step = 0
92
+ for key, name, lora, vision, fusion in phase_defs:
93
+ pcfg = training.get(key, {})
94
+ steps = pcfg.get("steps", 2500)
95
+ if dry_run:
96
+ steps = 10
97
+ start = cumulative_step
98
+ end = cumulative_step + steps
99
+ phases.append(PhaseConfig(
100
+ name=name,
101
+ start_step=start,
102
+ end_step=end,
103
+ learning_rate=float(pcfg.get("lr", 2e-4)),
104
+ batch_size=pcfg.get("batch_size", 8),
105
+ gradient_accumulation_steps=training.get("grad_accumulation", 4),
106
+ lora=lora,
107
+ vision_projection=vision,
108
+ fusion=fusion,
109
+ ))
110
+ cumulative_step = end
111
+
112
+ config = TrainingConfig(
113
+ train_file=PROJECT_ROOT / data.get("train_file", "data/processed/train.jsonl"),
114
+ val_file=PROJECT_ROOT / data.get("val_file", "data/processed/val.jsonl"),
115
+ output_dir=PROJECT_ROOT / output.get("checkpoint_dir", "checkpoints/training"),
116
+ log_dir=PROJECT_ROOT / logging_cfg.get("log_dir", "logs/training"),
117
+ max_seq_length=data.get("max_length", 4096),
118
+ use_compile=model_cfg.get("use_compile", False),
119
+ gradient_checkpointing=model_cfg.get("gradient_checkpointing", True),
120
+ dtype=model_cfg.get("dtype", "bf16"),
121
+ num_workers=data.get("num_workers", 4),
122
+ pin_memory=True,
123
+ prefetch_factor=2,
124
+ weight_decay=0.01,
125
+ warmup_ratio=0.03,
126
+ max_grad_norm=float(training.get("max_grad_norm", 1.0)),
127
+ seed=42,
128
+ log_every_n_steps=logging_cfg.get("log_every", 10),
129
+ eval_every_n_steps=training.get("eval_every", 250),
130
+ save_every_n_steps=training.get("save_every", 500),
131
+ phases=phases,
132
+ )
133
+
134
+ if dry_run:
135
+ config.eval_every_n_steps = 5
136
+ config.save_every_n_steps = 10
137
+ config.log_every_n_steps = 1
138
+
139
+ return config
140
+
141
+
142
+ def init_wandb(raw_config: dict, phase: str, disabled: bool = False):
143
+ """Initialize WandB logging."""
144
+ if disabled:
145
+ return None
146
+ try:
147
+ import wandb
148
+ logging_cfg = raw_config.get("logging", {})
149
+ run = wandb.init(
150
+ project=logging_cfg.get("wandb_project", "mindi-1.5-vision-coder"),
151
+ entity=logging_cfg.get("wandb_entity", "mindigenous"),
152
+ name=f"mindi15-{phase}",
153
+ config=raw_config,
154
+ tags=["mindi-1.5", "training", f"phase-{phase}"],
155
+ reinit=True,
156
+ )
157
+ print(f"[train.py] WandB initialized: {run.url}")
158
+ return run
159
+ except ImportError:
160
+ print("[train.py] WandB not installed β€” logging disabled")
161
+ return None
162
+ except Exception as e:
163
+ print(f"[train.py] WandB init failed: {e} β€” continuing without logging")
164
+ return None
165
+
166
+
167
+ def push_checkpoint_to_hub(checkpoint_dir: Path, raw_config: dict) -> None:
168
+ """Push a checkpoint to HuggingFace Hub."""
169
+ output = raw_config.get("output", {})
170
+ repo_id = output.get("hf_repo", "Mindigenous/MINDI-1.5-Vision-Coder")
171
+
172
+ try:
173
+ from huggingface_hub import HfApi
174
+ import os
175
+ api = HfApi(token=os.environ.get("HF_TOKEN"))
176
+
177
+ print(f"[train.py] Pushing checkpoint to {repo_id} ...")
178
+ api.upload_folder(
179
+ folder_path=str(checkpoint_dir),
180
+ repo_id=repo_id,
181
+ path_in_repo=f"checkpoints/{checkpoint_dir.name}",
182
+ repo_type="model",
183
+ )
184
+ print(f"[train.py] Pushed to https://huggingface.co/{repo_id}")
185
+ except ImportError:
186
+ print("[train.py] huggingface_hub not installed β€” skipping push")
187
+ except Exception as e:
188
+ print(f"[train.py] Push to hub failed: {e}")
189
+
190
+
191
+ def log_wandb_phase_complete(wandb_run, summary: dict) -> None:
192
+ """Log phase completion to WandB."""
193
+ if wandb_run is None:
194
+ return
195
+ try:
196
+ import wandb
197
+ wandb.log({
198
+ "phase_complete": True,
199
+ "phase": summary.get("phase", "unknown"),
200
+ "total_steps": summary.get("total_steps", 0),
201
+ "best_val_loss": summary.get("best_val_loss", 0),
202
+ "elapsed_minutes": summary.get("elapsed_minutes", 0),
203
+ })
204
+ except Exception:
205
+ pass
206
+
207
+
208
+ def main() -> None:
209
+ args = parse_args()
210
+
211
+ print()
212
+ print("=" * 60)
213
+ print(" MINDI 1.5 Vision-Coder β€” Training Launch")
214
+ print(" MINDIGENOUS.AI")
215
+ print("=" * 60)
216
+ print()
217
+ print(f" Phase: {args.phase}")
218
+ print(f" Config: {args.config}")
219
+ print(f" Resume: {args.resume or 'None'}")
220
+ print(f" Dry run: {args.dry_run}")
221
+ print(f" Push to hub: {args.push_to_hub}")
222
+ print(f" Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
223
+ if torch.cuda.is_available():
224
+ print(f" GPU: {torch.cuda.get_device_name(0)}")
225
+ vram_gb = torch.cuda.get_device_properties(0).total_mem / (1024 ** 3)
226
+ print(f" VRAM: {vram_gb:.1f} GB")
227
+ print()
228
+
229
+ # Load config
230
+ raw_config = load_config(args.config)
231
+ config = build_training_config(raw_config, dry_run=args.dry_run)
232
+
233
+ # Filter phases based on --phase arg
234
+ if args.phase != "all":
235
+ phase_idx = int(args.phase) - 1
236
+ if phase_idx < 0 or phase_idx >= len(config.phases):
237
+ print(f"ERROR: Invalid phase {args.phase}. Available: 1-{len(config.phases)}")
238
+ sys.exit(1)
239
+ selected_phase = config.phases[phase_idx]
240
+ # Adjust to start from 0 for single-phase run
241
+ step_count = selected_phase.end_step - selected_phase.start_step
242
+ selected_phase.start_step = 0
243
+ selected_phase.end_step = step_count
244
+ config.phases = [selected_phase]
245
+
246
+ # Initialize model
247
+ print("[train.py] Initializing MINDI 1.5 model ...")
248
+ from src.model.mindi_model import MINDI15
249
+ model_cfg = raw_config.get("model", {})
250
+ vision_cfg = raw_config.get("vision", {})
251
+
252
+ model = MINDI15(
253
+ model_name=model_cfg.get("name", "Qwen/Qwen2.5-Coder-7B-Instruct"),
254
+ clip_model=vision_cfg.get("clip_model", "openai/clip-vit-large-patch14"),
255
+ hidden_size=model_cfg.get("hidden_size", 4096),
256
+ num_visual_tokens=vision_cfg.get("visual_tokens", 256),
257
+ torch_dtype=config.torch_dtype,
258
  )
 
259
 
260
+ # Initialize trainer
261
+ from src.training.mindi_trainer import MINDITrainer
262
+ trainer = MINDITrainer(model=model, config=config)
263
+
264
+ # Resume from checkpoint
265
+ if args.resume:
266
+ resume_path = Path(args.resume)
267
+ if not resume_path.is_absolute():
268
+ resume_path = PROJECT_ROOT / resume_path
269
+ trainer.resume_from_checkpoint(resume_path)
270
+
271
+ # Initialize WandB
272
+ wandb_run = init_wandb(raw_config, args.phase, disabled=args.no_wandb)
273
+
274
+ # Graceful Ctrl+C handler
275
+ interrupted = False
276
+
277
+ def signal_handler(sig, frame):
278
+ nonlocal interrupted
279
+ if interrupted:
280
+ print("\n[train.py] Forced exit!")
281
+ sys.exit(1)
282
+ interrupted = True
283
+ print("\n[train.py] Ctrl+C received β€” saving checkpoint before exit ...")
284
+ try:
285
+ emergency_dir = config.output_dir / "emergency_checkpoint"
286
+ emergency_dir.mkdir(parents=True, exist_ok=True)
287
+ model.save(emergency_dir)
288
+ print(f"[train.py] Emergency checkpoint saved: {emergency_dir}")
289
+ except Exception as e:
290
+ print(f"[train.py] Emergency save failed: {e}")
291
+ sys.exit(0)
292
+
293
+ signal.signal(signal.SIGINT, signal_handler)
294
+
295
+ # Run training
296
+ try:
297
+ if args.phase == "all":
298
+ summary = trainer.train()
299
+
300
+ final_dir = config.output_dir / "final"
301
+ if args.push_to_hub:
302
+ push_checkpoint_to_hub(final_dir, raw_config)
303
+ log_wandb_phase_complete(wandb_run, summary)
304
+
305
+ else:
306
+ phase = config.phases[0]
307
+ summary = trainer.train_phase(phase)
308
+
309
+ ckpt_dir = config.output_dir / f"{phase.name}_step{phase.end_step}"
310
+ if args.push_to_hub:
311
+ push_checkpoint_to_hub(ckpt_dir, raw_config)
312
+ log_wandb_phase_complete(wandb_run, summary)
313
+
314
+ except KeyboardInterrupt:
315
+ signal_handler(None, None)
316
+ except Exception as e:
317
+ print(f"\n[train.py] ERROR: {e}")
318
+ traceback.print_exc()
319
+ try:
320
+ crash_dir = config.output_dir / "crash_checkpoint"
321
+ crash_dir.mkdir(parents=True, exist_ok=True)
322
+ model.save(crash_dir)
323
+ print(f"[train.py] Crash checkpoint saved: {crash_dir}")
324
+ except Exception:
325
+ pass
326
+ sys.exit(1)
327
+ finally:
328
+ if wandb_run is not None:
329
+ try:
330
+ import wandb
331
+ wandb.finish()
332
+ except Exception:
333
+ pass
334
 
335
+ # Final summary
336
+ hf_repo = raw_config.get("output", {}).get("hf_repo", "Mindigenous/MINDI-1.5-Vision-Coder")
337
+ print()
338
+ print("=" * 60)
339
+ print(" Training complete!")
340
+ print(f" Best val loss: {trainer.best_val_loss:.4f}")
341
+ print(f" Checkpoint at: {config.output_dir}")
342
+ if args.push_to_hub:
343
+ print(f" HuggingFace: https://huggingface.co/{hf_repo}")
344
+ print("=" * 60)
345
+ print()
346
 
347
 
348
  if __name__ == "__main__":
scripts/upload_everything_to_hf.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Upload ENTIRE MINDI 1.5 Vision-Coder project to HuggingFace.
4
+
5
+ REPO 1 (model): Mindigenous/MINDI-1.5-Vision-Coder
6
+ REPO 2 (dataset): Mindigenous/MINDI-1.5-training-data
7
+
8
+ Both private. On MI300X we will clone these repos directly.
9
+ """
10
+
11
+ import os
12
+ import sys
13
+ import time
14
+ from pathlib import Path
15
+
16
+ from dotenv import load_dotenv
17
+ from huggingface_hub import HfApi, create_repo
18
+
19
+ # ── Paths ──────────────────────────────────────────────────────────────
20
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
21
+ ENV_FILE = PROJECT_ROOT / ".env"
22
+
23
+ # ── Repo names ─────────────────────────────────────────────────────────
24
+ MODEL_REPO = "Mindigenous/MINDI-1.5-Vision-Coder"
25
+ DATASET_REPO = "Mindigenous/MINDI-1.5-training-data"
26
+
27
+ # ── Model card (written to repo as README.md) ─────────────────────────
28
+ MODEL_CARD = """\
29
+ ---
30
+ license: apache-2.0
31
+ language:
32
+ - en
33
+ tags:
34
+ - code-generation
35
+ - nextjs
36
+ - react
37
+ - typescript
38
+ - vision
39
+ - multimodal
40
+ - mindi
41
+ - mindigenous
42
+ base_model: Qwen/Qwen2.5-Coder-7B-Instruct
43
+ ---
44
+
45
+ # MINDI 1.5 Vision-Coder
46
+
47
+ Built by MINDIGENOUS.AI
48
+
49
+ ## Model Description
50
+ MINDI 1.5 is an agentic AI coding model
51
+ that sees its own output and critiques it.
52
+
53
+ ## Key Features
54
+ - Generates Next.js 14 + Tailwind + TypeScript
55
+ - Sees screenshots via CLIP ViT-L/14
56
+ - Critiques its own UI/UX output
57
+ - Searches internet for latest packages
58
+ - Tests code in sandbox environment
59
+ - Self-fixes errors automatically
60
+
61
+ ## Training
62
+ - Base: Qwen/Qwen2.5-Coder-7B-Instruct
63
+ - Method: LoRA fine-tuning
64
+ - Hardware: AMD MI300X 192GB VRAM
65
+ - Dataset: 1,449,428 examples
66
+ - Tokens: 859,694,776
67
+ - Status: Training in progress
68
+
69
+ ## Built By
70
+ Faaz - MINDIGENOUS.AI
71
+ Mumbai, India
72
+ April 2026
73
+ """
74
+
75
+ # ── Dataset card ───────────────────────────────────────────────────────
76
+ DATASET_CARD = """\
77
+ ---
78
+ license: apache-2.0
79
+ language:
80
+ - en
81
+ tags:
82
+ - code-generation
83
+ - nextjs
84
+ - react
85
+ - typescript
86
+ - vision
87
+ - multimodal
88
+ - mindi
89
+ - mindigenous
90
+ size_categories:
91
+ - 1M<n<10M
92
+ ---
93
+
94
+ # MINDI 1.5 Training Data
95
+
96
+ Training dataset for **MINDI 1.5 Vision-Coder** by MINDIGENOUS.AI
97
+
98
+ ## Dataset Statistics
99
+ | Metric | Value |
100
+ |--------|-------|
101
+ | Total examples | 1,449,428 |
102
+ | Total tokens | 859,694,776 |
103
+ | Avg tokens/example | 593 |
104
+ | Avg quality score | 6.49 |
105
+ | Sources | 9 |
106
+
107
+ ## Splits
108
+ | Split | Examples | Percentage |
109
+ |-------|----------|------------|
110
+ | Train | 1,304,486 | 90.0% |
111
+ | Validation | 72,471 | 5.0% |
112
+ | Test | 72,471 | 5.0% |
113
+
114
+ ## Sources
115
+ | Source | Examples | Kept % |
116
+ |--------|----------|--------|
117
+ | starcoderdata | 569,350 | 94.9% |
118
+ | websight | 250,987 | 99.99% |
119
+ | evol_code | 155,998 | 99.7% |
120
+ | codefeedback | 149,865 | 99.9% |
121
+ | magicoder | 149,987 | 99.99% |
122
+ | synthetic_nextjs | 90,000 | 100% (protected) |
123
+ | codealpaca | 59,241 | 98.8% |
124
+ | search_examples | 15,000 | 100% (protected) |
125
+ | sandbox_examples | 9,000 | 100% (protected) |
126
+
127
+ ## Type Distribution
128
+ | Type | Examples |
129
+ |------|----------|
130
+ | code_generation | 1,183,441 |
131
+ | vision_code | 250,987 |
132
+ | search | 15,000 |
133
+
134
+ ## Language Distribution
135
+ | Language | Examples |
136
+ |----------|----------|
137
+ | unknown | 490,305 |
138
+ | typescript | 375,859 |
139
+ | javascript | 298,497 |
140
+ | python | 211,842 |
141
+ | html | 36,371 |
142
+ | java | 32,458 |
143
+ | rust | 3,709 |
144
+ | go | 387 |
145
+
146
+ ## Format
147
+ Each example is a JSON object with:
148
+ - `conversations`: list of `{"role": ..., "content": ...}` turns
149
+ - `source`: dataset origin
150
+ - `type`: code_generation / vision_code / search
151
+ - `language`: programming language
152
+ - `quality_score`: heuristic quality (0-10+)
153
+ - `token_count`: number of tokens
154
+
155
+ ## Quality Filtering
156
+ - Protected sources (sandbox, search, synthetic_nextjs) bypass aggressive filters
157
+ - MINDI special token bonuses boost agentic examples
158
+ - Dedup via SHA-256 content hashing
159
+ - Rejection reasons: too_many_tokens (30,637), boilerplate (1,373), duplicate (59)
160
+
161
+ ## Built By
162
+ Faaz - MINDIGENOUS.AI
163
+ Mumbai, India β€” April 2026
164
+ """
165
+
166
+
167
+ # ────────────────────────────────────────────────────────────────────────
168
+ def load_token() -> str:
169
+ """Load HF token from .env."""
170
+ load_dotenv(ENV_FILE)
171
+ token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
172
+ if not token:
173
+ print("ERROR: No HUGGINGFACE_TOKEN or HF_TOKEN found in .env")
174
+ sys.exit(1)
175
+ return token
176
+
177
+
178
+ def ensure_repo(api: HfApi, repo_id: str, repo_type: str, token: str):
179
+ """Create repo if it doesn't exist."""
180
+ try:
181
+ create_repo(
182
+ repo_id=repo_id,
183
+ repo_type=repo_type,
184
+ private=True,
185
+ token=token,
186
+ exist_ok=True,
187
+ )
188
+ print(f" Repo ready: {repo_id} ({repo_type})")
189
+ except Exception as e:
190
+ print(f" Repo create/check: {e}")
191
+
192
+
193
+ def upload_folder(api: HfApi, local: Path, remote: str, repo_id: str,
194
+ repo_type: str, token: str):
195
+ """Upload a local folder to HF repo."""
196
+ if not local.exists():
197
+ print(f" SKIP (not found): {local}")
198
+ return
199
+ label = str(local.relative_to(PROJECT_ROOT))
200
+ print(f" Uploading {label}/ to {repo_type} repo ... ", end="", flush=True)
201
+ t0 = time.time()
202
+ api.upload_folder(
203
+ repo_id=repo_id,
204
+ repo_type=repo_type,
205
+ folder_path=str(local),
206
+ path_in_repo=remote,
207
+ token=token,
208
+ ignore_patterns=["__pycache__", "*.pyc", ".git"],
209
+ )
210
+ print(f"done ({time.time() - t0:.1f}s)")
211
+
212
+
213
+ def upload_file(api: HfApi, local: Path, remote: str, repo_id: str,
214
+ repo_type: str, token: str):
215
+ """Upload a single file to HF repo."""
216
+ if not local.exists():
217
+ print(f" SKIP (not found): {local.name}")
218
+ return
219
+ size_mb = local.stat().st_size / (1024 * 1024)
220
+ label = str(local.relative_to(PROJECT_ROOT))
221
+ print(f" Uploading {label} ({size_mb:.1f} MB) to {repo_type} repo ... ",
222
+ end="", flush=True)
223
+ t0 = time.time()
224
+ api.upload_file(
225
+ repo_id=repo_id,
226
+ repo_type=repo_type,
227
+ path_or_fileobj=str(local),
228
+ path_in_repo=remote,
229
+ token=token,
230
+ )
231
+ print(f"done ({time.time() - t0:.1f}s)")
232
+
233
+
234
+ def upload_readme(api: HfApi, content: str, repo_id: str,
235
+ repo_type: str, token: str):
236
+ """Upload a README.md string to a repo."""
237
+ print(f" Uploading README.md to {repo_type} repo ... ", end="", flush=True)
238
+ api.upload_file(
239
+ repo_id=repo_id,
240
+ repo_type=repo_type,
241
+ path_or_fileobj=content.encode("utf-8"),
242
+ path_in_repo="README.md",
243
+ token=token,
244
+ )
245
+ print("done")
246
+
247
+
248
+ # ────────────────────────────────────────────────────────────────────────
249
+ def main():
250
+ print("=" * 60)
251
+ print(" MINDI 1.5 β€” Upload Everything to HuggingFace")
252
+ print("=" * 60)
253
+ print()
254
+
255
+ token = load_token()
256
+ api = HfApi()
257
+
258
+ # ── Create repos ───────────────────────────────────────────────
259
+ print("[1/4] Creating repositories ...")
260
+ ensure_repo(api, MODEL_REPO, "model", token)
261
+ ensure_repo(api, DATASET_REPO, "dataset", token)
262
+ print()
263
+
264
+ # ── REPO 1: Model (code + configs) ─────────────────────────────
265
+ print("[2/4] Uploading to MODEL repo:", MODEL_REPO)
266
+ print("-" * 50)
267
+
268
+ # Folders
269
+ model_folders = [
270
+ (PROJECT_ROOT / "src", "src"),
271
+ (PROJECT_ROOT / "scripts", "scripts"),
272
+ (PROJECT_ROOT / "configs", "configs"),
273
+ (PROJECT_ROOT / "data" / "tokenizer", "data/tokenizer"),
274
+ (PROJECT_ROOT / "tests", "tests"),
275
+ (PROJECT_ROOT / "api", "api"),
276
+ ]
277
+ for local, remote in model_folders:
278
+ upload_folder(api, local, remote, MODEL_REPO, "model", token)
279
+
280
+ # Single files
281
+ model_files = [
282
+ (PROJECT_ROOT / "requirements.txt", "requirements.txt"),
283
+ (PROJECT_ROOT / "setup.py", "setup.py"),
284
+ (PROJECT_ROOT / "activate_mindi.bat", "activate_mindi.bat"),
285
+ (PROJECT_ROOT / ".env.example", ".env.example"),
286
+ ]
287
+ for local, remote in model_files:
288
+ upload_file(api, local, remote, MODEL_REPO, "model", token)
289
+
290
+ # setup_mi300x.sh
291
+ mi300x_sh = PROJECT_ROOT / "setup_mi300x.sh"
292
+ if mi300x_sh.exists():
293
+ upload_file(api, mi300x_sh, "setup_mi300x.sh", MODEL_REPO, "model", token)
294
+
295
+ # Model card replaces README.md
296
+ upload_readme(api, MODEL_CARD, MODEL_REPO, "model", token)
297
+ print()
298
+
299
+ # ── REPO 2: Dataset ────────────────────────────────────────────
300
+ print("[3/4] Uploading to DATASET repo:", DATASET_REPO)
301
+ print("-" * 50)
302
+
303
+ processed = PROJECT_ROOT / "data" / "processed"
304
+ dataset_files = [
305
+ (processed / "train.jsonl", "processed/train.jsonl"),
306
+ (processed / "val.jsonl", "processed/val.jsonl"),
307
+ (processed / "test.jsonl", "processed/test.jsonl"),
308
+ (processed / "mindi_filtered.jsonl", "processed/mindi_filtered.jsonl"),
309
+ (processed / "filter_report.json", "processed/filter_report.json"),
310
+ (processed / "split_meta.json", "processed/split_meta.json"),
311
+ ]
312
+ for local, remote in dataset_files:
313
+ upload_file(api, local, remote, DATASET_REPO, "dataset", token)
314
+
315
+ # Raw data folder
316
+ upload_folder(
317
+ api, PROJECT_ROOT / "data" / "raw", "raw",
318
+ DATASET_REPO, "dataset", token,
319
+ )
320
+
321
+ # Tokenizer copy in dataset repo
322
+ upload_folder(
323
+ api, PROJECT_ROOT / "data" / "tokenizer", "tokenizer",
324
+ DATASET_REPO, "dataset", token,
325
+ )
326
+
327
+ # Dataset card
328
+ upload_readme(api, DATASET_CARD, DATASET_REPO, "dataset", token)
329
+ print()
330
+
331
+ # ── Done ───────────────────────────────────────────────────────
332
+ print("[4/4] Upload complete!")
333
+ print()
334
+ print("╔══════════════════════════════════════╗")
335
+ print("β•‘ UPLOAD COMPLETE! β•‘")
336
+ print("β•‘ β•‘")
337
+ print("β•‘ Model repo: β•‘")
338
+ print("β•‘ huggingface.co/Mindigenous/ β•‘")
339
+ print("β•‘ MINDI-1.5-Vision-Coder β•‘")
340
+ print("β•‘ β•‘")
341
+ print("β•‘ Dataset repo: β•‘")
342
+ print("β•‘ huggingface.co/datasets/ β•‘")
343
+ print("β•‘ Mindigenous/MINDI-1.5-training-data β•‘")
344
+ print("β•‘ β•‘")
345
+ print("β•‘ On MI300X just run: β•‘")
346
+ print("β•‘ git clone https://huggingface.co/ β•‘")
347
+ print("β•‘ Mindigenous/MINDI-1.5-Vision-Coder β•‘")
348
+ print("β•‘ β•‘")
349
+ print("β•‘ Ready to train! πŸš€ β•‘")
350
+ print("β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•")
351
+
352
+
353
+ if __name__ == "__main__":
354
+ main()
setup_mi300x.sh ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # ============================================================
3
+ # MINDI 1.5 Vision-Coder β€” MI300X Setup Script
4
+ # One command to set up everything on DigitalOcean AMD MI300X
5
+ # ============================================================
6
+ set -e
7
+
8
+ echo "============================================================"
9
+ echo " MINDI 1.5 Vision-Coder β€” MI300X Setup"
10
+ echo " MINDIGENOUS.AI"
11
+ echo "============================================================"
12
+ echo ""
13
+
14
+ # ── Check HF_TOKEN ─────────────────────────────────────────────
15
+ if [ -z "$HF_TOKEN" ]; then
16
+ echo "ERROR: Set HF_TOKEN environment variable first!"
17
+ echo " export HF_TOKEN=hf_your_token_here"
18
+ exit 1
19
+ fi
20
+
21
+ # ── Step 1: Install ROCm PyTorch ───────────────────────────────
22
+ echo "[1/7] Installing ROCm PyTorch (ROCm 6.0) ..."
23
+ pip install torch torchvision torchaudio \
24
+ --index-url https://download.pytorch.org/whl/rocm6.0
25
+
26
+ # ── Step 2: Clone the full project from HF ────────────────────
27
+ echo ""
28
+ echo "[2/7] Cloning MINDI 1.5 from HuggingFace ..."
29
+ if [ -d "MINDI-1.5-Vision-Coder" ]; then
30
+ echo " Directory exists β€” pulling latest ..."
31
+ cd MINDI-1.5-Vision-Coder
32
+ git pull
33
+ else
34
+ git clone https://${HF_TOKEN}@huggingface.co/Mindigenous/MINDI-1.5-Vision-Coder
35
+ cd MINDI-1.5-Vision-Coder
36
+ fi
37
+
38
+ # ── Step 3: Install Python requirements ────────────────────────
39
+ echo ""
40
+ echo "[3/7] Installing Python requirements ..."
41
+ pip install -r requirements.txt
42
+
43
+ # Additional training dependencies
44
+ pip install wandb huggingface_hub accelerate
45
+
46
+ # ── Step 4: Download training data from HF ─────────────────────
47
+ echo ""
48
+ echo "[4/7] Downloading training dataset ..."
49
+ python -c "
50
+ from huggingface_hub import snapshot_download
51
+ import os
52
+
53
+ snapshot_download(
54
+ repo_id='Mindigenous/MINDI-1.5-training-data',
55
+ repo_type='dataset',
56
+ local_dir='data/',
57
+ token=os.environ['HF_TOKEN']
58
+ )
59
+ print('Dataset downloaded!')
60
+ "
61
+
62
+ # Verify data files exist
63
+ echo " Checking data files ..."
64
+ if [ ! -f "data/processed/train.jsonl" ]; then
65
+ echo " ERROR: train.jsonl not found!"
66
+ exit 1
67
+ fi
68
+ if [ ! -f "data/processed/val.jsonl" ]; then
69
+ echo " ERROR: val.jsonl not found!"
70
+ exit 1
71
+ fi
72
+ TRAIN_SIZE=$(du -sh data/processed/train.jsonl | cut -f1)
73
+ VAL_SIZE=$(du -sh data/processed/val.jsonl | cut -f1)
74
+ echo " train.jsonl: ${TRAIN_SIZE}"
75
+ echo " val.jsonl: ${VAL_SIZE}"
76
+
77
+ # ── Step 5: Set environment variables ──────────────────────────
78
+ echo ""
79
+ echo "[5/7] Setting environment variables ..."
80
+
81
+ # ROCm / PyTorch settings
82
+ export HSA_OVERRIDE_GFX_VERSION=11.0.0
83
+ export PYTORCH_ROCM_ARCH="gfx942"
84
+ export HIP_VISIBLE_DEVICES=0
85
+ export TOKENIZERS_PARALLELISM=false
86
+ export WANDB_PROJECT="mindi-1.5-vision-coder"
87
+
88
+ # Create .env file
89
+ cat > .env << EOF
90
+ HF_TOKEN=${HF_TOKEN}
91
+ HSA_OVERRIDE_GFX_VERSION=11.0.0
92
+ PYTORCH_ROCM_ARCH=gfx942
93
+ HIP_VISIBLE_DEVICES=0
94
+ TOKENIZERS_PARALLELISM=false
95
+ WANDB_PROJECT=mindi-1.5-vision-coder
96
+ EOF
97
+ echo " .env file created"
98
+
99
+ # ── Step 6: Verify GPU detected ───────────────────────────────
100
+ echo ""
101
+ echo "[6/7] Verifying GPU ..."
102
+ python -c "
103
+ import torch
104
+ print(f' PyTorch version: {torch.__version__}')
105
+ print(f' CUDA available: {torch.cuda.is_available()}')
106
+ if torch.cuda.is_available():
107
+ print(f' GPU name: {torch.cuda.get_device_name(0)}')
108
+ vram = torch.cuda.get_device_properties(0).total_mem / (1024**3)
109
+ print(f' VRAM: {vram:.1f} GB')
110
+ print(f' ROCm backend: {torch.version.hip is not None}')
111
+ else:
112
+ print(' WARNING: No GPU detected!')
113
+ exit(1)
114
+ "
115
+
116
+ # Quick bf16 test
117
+ python -c "
118
+ import torch
119
+ x = torch.randn(100, 100, dtype=torch.bfloat16, device='cuda')
120
+ y = torch.matmul(x, x.T)
121
+ print(f' bf16 matmul test: PASSED (shape={y.shape})')
122
+ "
123
+
124
+ # ── Step 7: Create output directories ─────────────────────────
125
+ echo ""
126
+ echo "[7/7] Creating output directories ..."
127
+ mkdir -p checkpoints/training
128
+ mkdir -p checkpoints/best
129
+ mkdir -p logs/training
130
+
131
+ # ── Done ───────────────────────────────────────────────────────
132
+ echo ""
133
+ echo "============================================================"
134
+ echo " MINDI 1.5 Vision-Coder β€” MI300X Ready!"
135
+ echo ""
136
+ echo " Project: $(pwd)"
137
+ echo " Data: ${TRAIN_SIZE} train / ${VAL_SIZE} val"
138
+ echo " GPU: $(python -c 'import torch; print(torch.cuda.get_device_name(0))' 2>/dev/null || echo 'N/A')"
139
+ echo ""
140
+ echo " Ready to train!"
141
+ echo " Run: python scripts/train.py --phase 1"
142
+ echo ""
143
+ echo " Or dry run first:"
144
+ echo " Run: python scripts/train.py --dry_run --no_wandb"
145
+ echo "============================================================"
src/model/architecture.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MINDI 1.5 Vision-Coder β€” Model Architecture
3
+
4
+ Loads Qwen/Qwen2.5-Coder-7B-Instruct with LoRA adapters.
5
+ Handles model initialization, LoRA application, save/load,
6
+ and parameter counting for the base LLM component.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from pathlib import Path
12
+ from typing import Optional
13
+
14
+ import torch
15
+ from peft import LoraConfig, PeftModel, TaskType, get_peft_model
16
+ from transformers import AutoModelForCausalLM, AutoTokenizer
17
+
18
+
19
+ class MINDIArchitecture:
20
+ """Qwen2.5-Coder-7B-Instruct with LoRA for MINDI 1.5 fine-tuning."""
21
+
22
+ DEFAULT_TARGET_MODULES: list[str] = [
23
+ "q_proj", "k_proj", "v_proj", "o_proj",
24
+ "gate_proj", "up_proj", "down_proj",
25
+ ]
26
+
27
+ def __init__(
28
+ self,
29
+ model_name: str = "Qwen/Qwen2.5-Coder-7B-Instruct",
30
+ device: Optional[str] = None,
31
+ cache_dir: Optional[Path] = None,
32
+ torch_dtype: torch.dtype = torch.bfloat16,
33
+ ) -> None:
34
+ """
35
+ Initialize the architecture wrapper.
36
+
37
+ Args:
38
+ model_name: HuggingFace model identifier.
39
+ device: Target device ('cuda', 'cpu', or None for auto).
40
+ cache_dir: Local directory for model weight cache.
41
+ torch_dtype: Data type for model weights.
42
+ """
43
+ self.model_name = model_name
44
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
45
+ self.cache_dir = Path(cache_dir) if cache_dir else Path("./checkpoints/base")
46
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
47
+ self.torch_dtype = torch_dtype
48
+
49
+ self.model: Optional[AutoModelForCausalLM] = None
50
+ self.peft_model: Optional[PeftModel] = None
51
+ self.tokenizer: Optional[AutoTokenizer] = None
52
+
53
+ self._load_model()
54
+
55
+ def _load_model(self) -> None:
56
+ """Load the base model and tokenizer from HuggingFace or cache."""
57
+ print(f"[MINDIArchitecture] Loading {self.model_name} ...")
58
+ self.model = AutoModelForCausalLM.from_pretrained(
59
+ self.model_name,
60
+ cache_dir=str(self.cache_dir),
61
+ torch_dtype=self.torch_dtype,
62
+ device_map="auto" if self.device == "cuda" else None,
63
+ trust_remote_code=True,
64
+ )
65
+ self.tokenizer = AutoTokenizer.from_pretrained(
66
+ self.model_name,
67
+ cache_dir=str(self.cache_dir),
68
+ trust_remote_code=True,
69
+ )
70
+ print(f"[MINDIArchitecture] Loaded on {self.device} "
71
+ f"({self._fmt_params(self._total_params())} params)")
72
+
73
+ def apply_lora(
74
+ self,
75
+ r: int = 64,
76
+ lora_alpha: int = 128,
77
+ lora_dropout: float = 0.05,
78
+ target_modules: Optional[list[str]] = None,
79
+ ) -> PeftModel:
80
+ """
81
+ Apply LoRA adapters to the base model.
82
+
83
+ Args:
84
+ r: LoRA rank.
85
+ lora_alpha: LoRA scaling factor.
86
+ lora_dropout: Dropout probability for LoRA layers.
87
+ target_modules: List of module names to apply LoRA to.
88
+
89
+ Returns:
90
+ The PEFT-wrapped model.
91
+ """
92
+ if self.model is None:
93
+ raise RuntimeError("Base model not loaded.")
94
+
95
+ if target_modules is None:
96
+ target_modules = self.DEFAULT_TARGET_MODULES
97
+
98
+ lora_config = LoraConfig(
99
+ r=r,
100
+ lora_alpha=lora_alpha,
101
+ lora_dropout=lora_dropout,
102
+ target_modules=target_modules,
103
+ bias="none",
104
+ task_type=TaskType.CAUSAL_LM,
105
+ )
106
+
107
+ self.peft_model = get_peft_model(self.model, lora_config)
108
+
109
+ info = self.get_trainable_params()
110
+ print(f"[MINDIArchitecture] LoRA applied (r={r}, alpha={lora_alpha})")
111
+ print(f" Trainable: {info['trainable']:>14,} ({info['trainable_pct']:.2f}%)")
112
+ print(f" Frozen: {info['frozen']:>14,}")
113
+ print(f" Total: {info['total']:>14,}")
114
+
115
+ return self.peft_model
116
+
117
+ def get_trainable_params(self) -> dict:
118
+ """
119
+ Count trainable, frozen, and total parameters.
120
+
121
+ Returns:
122
+ Dictionary with 'trainable', 'frozen', 'total', 'trainable_pct'.
123
+ """
124
+ model = self.peft_model or self.model
125
+ if model is None:
126
+ return {"trainable": 0, "frozen": 0, "total": 0, "trainable_pct": 0.0}
127
+
128
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
129
+ total = sum(p.numel() for p in model.parameters())
130
+ frozen = total - trainable
131
+ pct = 100.0 * trainable / total if total > 0 else 0.0
132
+
133
+ return {
134
+ "trainable": trainable,
135
+ "frozen": frozen,
136
+ "total": total,
137
+ "trainable_pct": round(pct, 4),
138
+ }
139
+
140
+ def print_model_info(self) -> None:
141
+ """Print detailed model architecture and parameter information."""
142
+ model = self.peft_model or self.model
143
+ if model is None:
144
+ print("[MINDIArchitecture] No model loaded.")
145
+ return
146
+
147
+ info = self.get_trainable_params()
148
+ print()
149
+ print("=" * 60)
150
+ print(" MINDI 1.5 β€” Model Architecture Info")
151
+ print("=" * 60)
152
+ print(f" Base model: {self.model_name}")
153
+ print(f" Device: {self.device}")
154
+ print(f" Dtype: {self.torch_dtype}")
155
+ print(f" LoRA active: {self.peft_model is not None}")
156
+ print(f" Total params: {self._fmt_params(info['total'])}")
157
+ print(f" Trainable: {self._fmt_params(info['trainable'])} "
158
+ f"({info['trainable_pct']:.2f}%)")
159
+ print(f" Frozen: {self._fmt_params(info['frozen'])}")
160
+
161
+ if self.peft_model is not None:
162
+ config = self.peft_model.peft_config.get("default")
163
+ if config is not None:
164
+ print(f" LoRA rank: {config.r}")
165
+ print(f" LoRA alpha: {config.lora_alpha}")
166
+ print(f" LoRA dropout: {config.lora_dropout}")
167
+ print(f" Target modules: {config.target_modules}")
168
+ print("=" * 60)
169
+ print()
170
+
171
+ def save_lora(self, path: Optional[Path] = None) -> Path:
172
+ """
173
+ Save LoRA adapter weights to disk.
174
+
175
+ Args:
176
+ path: Directory to save to. Defaults to checkpoints/lora.
177
+
178
+ Returns:
179
+ Path where weights were saved.
180
+ """
181
+ if self.peft_model is None:
182
+ raise RuntimeError("No LoRA adapter to save. Call apply_lora() first.")
183
+
184
+ save_path = Path(path) if path else Path("./checkpoints/lora")
185
+ save_path.mkdir(parents=True, exist_ok=True)
186
+ self.peft_model.save_pretrained(str(save_path))
187
+ print(f"[MINDIArchitecture] LoRA saved to {save_path}")
188
+ return save_path
189
+
190
+ def load_lora(self, path: Path) -> PeftModel:
191
+ """
192
+ Load LoRA adapter weights from disk.
193
+
194
+ Args:
195
+ path: Directory containing saved adapter weights.
196
+
197
+ Returns:
198
+ The PEFT-wrapped model with loaded adapter.
199
+ """
200
+ path = Path(path)
201
+ if not path.exists():
202
+ raise FileNotFoundError(f"LoRA adapter not found: {path}")
203
+ if self.model is None:
204
+ raise RuntimeError("Base model not loaded.")
205
+
206
+ self.peft_model = PeftModel.from_pretrained(
207
+ self.model, str(path)
208
+ )
209
+ print(f"[MINDIArchitecture] LoRA loaded from {path}")
210
+ return self.peft_model
211
+
212
+ def resize_embeddings(self, new_vocab_size: int) -> None:
213
+ """Resize model embeddings for new special tokens."""
214
+ model = self.peft_model or self.model
215
+ if model is None:
216
+ raise RuntimeError("No model loaded.")
217
+ old_size = model.get_input_embeddings().weight.shape[0]
218
+ if new_vocab_size != old_size:
219
+ model.resize_token_embeddings(new_vocab_size)
220
+ print(f"[MINDIArchitecture] Resized embeddings: {old_size} β†’ {new_vocab_size}")
221
+
222
+ def get_model(self) -> AutoModelForCausalLM | PeftModel:
223
+ """Return the active model (PEFT if LoRA applied, else base)."""
224
+ model = self.peft_model or self.model
225
+ if model is None:
226
+ raise RuntimeError("No model loaded.")
227
+ return model
228
+
229
+ # ── helpers ───────────────────────────────────────────────────
230
+ def _total_params(self) -> int:
231
+ model = self.peft_model or self.model
232
+ if model is None:
233
+ return 0
234
+ return sum(p.numel() for p in model.parameters())
235
+
236
+ @staticmethod
237
+ def _fmt_params(n: int) -> str:
238
+ if n >= 1_000_000_000:
239
+ return f"{n / 1_000_000_000:.2f}B"
240
+ if n >= 1_000_000:
241
+ return f"{n / 1_000_000:.2f}M"
242
+ if n >= 1_000:
243
+ return f"{n / 1_000:.1f}K"
244
+ return str(n)
245
+
246
+
247
+ # ── Test block ────────────────────────────────────────────────────────
248
+ if __name__ == "__main__":
249
+ print("=" * 60)
250
+ print(" MINDI 1.5 β€” Architecture Test")
251
+ print("=" * 60)
252
+ print()
253
+
254
+ # 1. Load base model
255
+ arch = MINDIArchitecture(
256
+ model_name="Qwen/Qwen2.5-Coder-7B-Instruct",
257
+ )
258
+
259
+ # 2. Apply LoRA
260
+ peft_model = arch.apply_lora(
261
+ r=64,
262
+ lora_alpha=128,
263
+ lora_dropout=0.05,
264
+ )
265
+
266
+ # 3. Print full info
267
+ arch.print_model_info()
268
+
269
+ # 4. Verify trainable params
270
+ info = arch.get_trainable_params()
271
+ assert info["trainable"] > 0, "No trainable parameters!"
272
+ assert info["frozen"] > info["trainable"], "More trainable than frozen β€” LoRA may not be applied!"
273
+
274
+ # 5. Verify LoRA modules exist
275
+ lora_modules = [name for name, _ in peft_model.named_parameters() if "lora_" in name]
276
+ print(f" LoRA modules found: {len(lora_modules)}")
277
+ assert len(lora_modules) > 0, "No LoRA modules found!"
278
+
279
+ # 6. Quick forward pass test (small input)
280
+ print("\n Running forward pass test ...")
281
+ test_input = arch.tokenizer("Hello MINDI!", return_tensors="pt")
282
+ test_input = {k: v.to(arch.device) for k, v in test_input.items()}
283
+ with torch.no_grad():
284
+ output = peft_model(**test_input)
285
+ print(f" Output logits shape: {output.logits.shape}")
286
+ print(f" Loss: {output.loss}")
287
+
288
+ print("\n βœ“ All architecture tests passed!")
289
+ print("=" * 60)
src/model/fusion_layer.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MINDI 1.5 Vision-Coder β€” Vision-Language Fusion Layer
3
+
4
+ Prepends projected visual tokens (256 Γ— 4096) to text token embeddings
5
+ and extends the attention mask accordingly. Uses Linear + LayerNorm
6
+ for the visual projection gate.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Optional
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+
16
+
17
+ class VisionLanguageFusion(nn.Module):
18
+ """
19
+ Fuses visual and text embeddings by prepending visual tokens.
20
+
21
+ Pipeline:
22
+ 1. visual_tokens (batch, 256, 4096) β†’ Linear β†’ LayerNorm
23
+ 2. Prepend to text_embeds (batch, seq_len, 4096)
24
+ 3. Extend attention_mask to cover the extra 256 visual positions
25
+
26
+ All trainable parameters live in the gate projection + LayerNorm.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ hidden_size: int = 4096,
32
+ num_visual_tokens: int = 256,
33
+ ) -> None:
34
+ """
35
+ Initialize the fusion layer.
36
+
37
+ Args:
38
+ hidden_size: Dimension of both visual and text embeddings (must match).
39
+ num_visual_tokens: Number of visual tokens prepended (default 256).
40
+ """
41
+ super().__init__()
42
+ self.hidden_size = hidden_size
43
+ self.num_visual_tokens = num_visual_tokens
44
+
45
+ # Gate projection: Linear + LayerNorm to align visual features
46
+ self.gate_proj = nn.Linear(hidden_size, hidden_size)
47
+ self.layer_norm = nn.LayerNorm(hidden_size)
48
+
49
+ def forward(
50
+ self,
51
+ text_embeds: torch.Tensor,
52
+ visual_tokens: Optional[torch.Tensor] = None,
53
+ attention_mask: Optional[torch.Tensor] = None,
54
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
55
+ """
56
+ Fuse visual tokens into text embeddings.
57
+
58
+ Args:
59
+ text_embeds: Text token embeddings (batch, seq_len, hidden_size).
60
+ visual_tokens: Projected visual tokens (batch, 256, hidden_size), or None
61
+ for text-only inputs.
62
+ attention_mask: Text attention mask (batch, seq_len), or None.
63
+
64
+ Returns:
65
+ fused_embeds: (batch, 256 + seq_len, hidden_size) if visual, else unchanged.
66
+ fused_mask: Extended attention mask, or None if input mask was None.
67
+ """
68
+ # Text-only path β€” no vision tokens to fuse
69
+ if visual_tokens is None:
70
+ return text_embeds, attention_mask
71
+
72
+ batch_size = text_embeds.shape[0]
73
+ v_batch = visual_tokens.shape[0]
74
+
75
+ # Handle batch size mismatch (single image broadcast to batch)
76
+ if v_batch == 1 and batch_size > 1:
77
+ visual_tokens = visual_tokens.expand(batch_size, -1, -1)
78
+
79
+ # Gate projection + LayerNorm
80
+ gated_visual = self.gate_proj(visual_tokens) # (batch, 256, hidden_size)
81
+ gated_visual = self.layer_norm(gated_visual) # (batch, 256, hidden_size)
82
+
83
+ # Prepend visual tokens to text embeddings
84
+ fused_embeds = torch.cat([gated_visual, text_embeds], dim=1)
85
+
86
+ # Extend attention mask
87
+ fused_mask = self._extend_attention_mask(attention_mask, batch_size, text_embeds.device)
88
+
89
+ return fused_embeds, fused_mask
90
+
91
+ def _extend_attention_mask(
92
+ self,
93
+ attention_mask: Optional[torch.Tensor],
94
+ batch_size: int,
95
+ device: torch.device,
96
+ ) -> Optional[torch.Tensor]:
97
+ """
98
+ Extend attention mask to include visual token positions (all attended).
99
+
100
+ Args:
101
+ attention_mask: Original text mask (batch, seq_len) or None.
102
+ batch_size: Current batch size.
103
+ device: Target device.
104
+
105
+ Returns:
106
+ Extended mask (batch, 256 + seq_len) or None.
107
+ """
108
+ if attention_mask is None:
109
+ return None
110
+
111
+ # Visual tokens are always fully attended
112
+ visual_mask = torch.ones(
113
+ batch_size,
114
+ self.num_visual_tokens,
115
+ dtype=attention_mask.dtype,
116
+ device=device,
117
+ )
118
+ return torch.cat([visual_mask, attention_mask], dim=1)
119
+
120
+ def get_trainable_params(self) -> dict:
121
+ """
122
+ Count trainable parameters in the fusion layer.
123
+
124
+ Returns:
125
+ Dictionary with 'trainable', 'total', and 'trainable_pct'.
126
+ """
127
+ trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
128
+ total = sum(p.numel() for p in self.parameters())
129
+ pct = 100.0 * trainable / total if total > 0 else 0.0
130
+ return {
131
+ "trainable": trainable,
132
+ "total": total,
133
+ "trainable_pct": round(pct, 4),
134
+ }
135
+
136
+ def extra_repr(self) -> str:
137
+ return (
138
+ f"hidden_size={self.hidden_size}, "
139
+ f"num_visual_tokens={self.num_visual_tokens}"
140
+ )
141
+
142
+
143
+ # ── Test block ────────────────────────────────────────────────────────
144
+ if __name__ == "__main__":
145
+ print("=" * 60)
146
+ print(" MINDI 1.5 β€” Fusion Layer Test")
147
+ print("=" * 60)
148
+ print()
149
+
150
+ BATCH = 2
151
+ SEQ_LEN = 128
152
+ HIDDEN = 4096
153
+ N_VIS = 256
154
+
155
+ fusion = VisionLanguageFusion(hidden_size=HIDDEN, num_visual_tokens=N_VIS)
156
+ print(f" Fusion layer:\n {fusion}\n")
157
+
158
+ # ── Test 1: Vision + Text fusion ─────────────────────────────
159
+ print(" Test 1: Vision + Text fusion")
160
+ text_embeds = torch.randn(BATCH, SEQ_LEN, HIDDEN)
161
+ visual_tokens = torch.randn(BATCH, N_VIS, HIDDEN)
162
+ attention_mask = torch.ones(BATCH, SEQ_LEN, dtype=torch.long)
163
+
164
+ fused_embeds, fused_mask = fusion(text_embeds, visual_tokens, attention_mask)
165
+
166
+ expected_seq = N_VIS + SEQ_LEN # 256 + 128 = 384
167
+ assert fused_embeds.shape == (BATCH, expected_seq, HIDDEN), \
168
+ f"Expected ({BATCH}, {expected_seq}, {HIDDEN}), got {fused_embeds.shape}"
169
+ assert fused_mask is not None and fused_mask.shape == (BATCH, expected_seq), \
170
+ f"Expected mask ({BATCH}, {expected_seq}), got {fused_mask.shape}"
171
+ print(f" fused_embeds: {fused_embeds.shape} βœ“")
172
+ print(f" fused_mask: {fused_mask.shape} βœ“")
173
+
174
+ # ── Test 2: Text-only (no vision) ────────────────────────────
175
+ print("\n Test 2: Text-only (no vision)")
176
+ text_only, mask_only = fusion(text_embeds, None, attention_mask)
177
+ assert text_only.shape == (BATCH, SEQ_LEN, HIDDEN)
178
+ assert mask_only is not None and mask_only.shape == (BATCH, SEQ_LEN)
179
+ print(f" text_only: {text_only.shape} βœ“")
180
+ print(f" mask_only: {mask_only.shape} βœ“")
181
+
182
+ # ── Test 3: No attention mask ────────────────────────────────
183
+ print("\n Test 3: Vision fusion without attention mask")
184
+ fused_no_mask, none_mask = fusion(text_embeds, visual_tokens, None)
185
+ assert fused_no_mask.shape == (BATCH, expected_seq, HIDDEN)
186
+ assert none_mask is None
187
+ print(f" fused_embeds: {fused_no_mask.shape} βœ“")
188
+ print(f" fused_mask: None βœ“")
189
+
190
+ # ── Test 4: Single-image broadcast ───────────────────────────
191
+ print("\n Test 4: Single-image broadcast to batch")
192
+ single_visual = torch.randn(1, N_VIS, HIDDEN)
193
+ fused_bc, mask_bc = fusion(text_embeds, single_visual, attention_mask)
194
+ assert fused_bc.shape == (BATCH, expected_seq, HIDDEN)
195
+ print(f" fused_embeds: {fused_bc.shape} βœ“ (broadcast 1 β†’ {BATCH})")
196
+
197
+ # ── Test 5: Trainable params ─────────────────────────────────
198
+ print("\n Test 5: Parameter counts")
199
+ info = fusion.get_trainable_params()
200
+ # gate_proj: 4096*4096 + 4096 = 16,781,312
201
+ # layer_norm: 4096 + 4096 = 8,192
202
+ expected_params = HIDDEN * HIDDEN + HIDDEN + HIDDEN + HIDDEN # Linear(w+b) + LN(w+b)
203
+ assert info["trainable"] == expected_params, \
204
+ f"Expected {expected_params}, got {info['trainable']}"
205
+ print(f" Trainable: {info['trainable']:,}")
206
+ print(f" Total: {info['total']:,}")
207
+ print(f" Pct: {info['trainable_pct']}%")
208
+
209
+ # ── Test 6: Gradient flow ────────────────────────────────────
210
+ print("\n Test 6: Gradient flow through fusion")
211
+ fusion.zero_grad()
212
+ fused_embeds, _ = fusion(text_embeds, visual_tokens, attention_mask)
213
+ loss = fused_embeds.sum()
214
+ loss.backward()
215
+ assert fusion.gate_proj.weight.grad is not None, "No gradient on gate_proj!"
216
+ assert fusion.layer_norm.weight.grad is not None, "No gradient on layer_norm!"
217
+ print(" gate_proj gradient: βœ“")
218
+ print(" layer_norm gradient: βœ“")
219
+
220
+ print("\n βœ“ All fusion layer tests passed!")
221
+ print("=" * 60)
src/model/mindi_model.py ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MINDI 1.5 Vision-Coder β€” Complete Model
3
+
4
+ Combines MINDIArchitecture (Qwen2.5-Coder + LoRA), VisionEncoder (CLIP ViT-L/14),
5
+ and VisionLanguageFusion into a single MINDI15 class with forward(), generate(),
6
+ parse_output(), save(), and load() methods.
7
+
8
+ Uses the MINDI custom tokenizer (data/tokenizer/mindi_tokenizer/) with 22 special
9
+ tokens for agentic code generation capabilities.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import re
15
+ from pathlib import Path
16
+ from typing import Optional
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from PIL import Image
21
+ from transformers import AutoTokenizer, PreTrainedTokenizerFast
22
+
23
+ from src.model.architecture import MINDIArchitecture
24
+ from src.model.fusion_layer import VisionLanguageFusion
25
+ from src.model.vision_encoder import VisionEncoder
26
+
27
+ # ── MINDI special token pairs ────────────────────────────────────────
28
+ MINDI_SECTION_TOKENS: dict[str, tuple[str, str]] = {
29
+ "thinking": ("<|think_start|>", "<|think_end|>"),
30
+ "file": ("<|file_start|>", "<|file_end|>"),
31
+ "code": ("<|code_start|>", "<|code_end|>"),
32
+ "critique": ("<|critique_start|>", "<|critique_end|>"),
33
+ "suggest": ("<|suggest_start|>", "<|suggest_end|>"),
34
+ "search": ("<|search_start|>", "<|search_end|>"),
35
+ "error": ("<|error_start|>", "<|error_end|>"),
36
+ "fix": ("<|fix_start|>", "<|fix_end|>"),
37
+ }
38
+
39
+ # Project root (resolved relative to this file)
40
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
41
+ DEFAULT_TOKENIZER_PATH = PROJECT_ROOT / "data" / "tokenizer" / "mindi_tokenizer"
42
+
43
+
44
+ class MINDI15(nn.Module):
45
+ """
46
+ MINDI 1.5 Vision-Coder β€” complete multimodal coding model.
47
+
48
+ Components:
49
+ - architecture: Qwen2.5-Coder-7B-Instruct + LoRA
50
+ - vision_encoder: CLIP ViT-L/14 (frozen) β†’ 256 tokens Γ— 4096
51
+ - fusion: Linear + LayerNorm prepend fusion
52
+ - tokenizer: MINDI custom tokenizer with 22 special tokens
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ model_name: str = "Qwen/Qwen2.5-Coder-7B-Instruct",
58
+ clip_model: str = "openai/clip-vit-large-patch14",
59
+ hidden_size: int = 4096,
60
+ num_visual_tokens: int = 256,
61
+ tokenizer_path: Optional[Path] = None,
62
+ device: Optional[str] = None,
63
+ torch_dtype: torch.dtype = torch.bfloat16,
64
+ cache_dir: Optional[Path] = None,
65
+ ) -> None:
66
+ """
67
+ Initialize MINDI 1.5 with all components.
68
+
69
+ Args:
70
+ model_name: HuggingFace base LLM identifier.
71
+ clip_model: HuggingFace CLIP vision model identifier.
72
+ hidden_size: LLM hidden dimension (must match Qwen config).
73
+ num_visual_tokens: Number of visual tokens from CLIP (256).
74
+ tokenizer_path: Path to MINDI custom tokenizer directory.
75
+ device: Target device ('cuda', 'cpu', or None for auto).
76
+ torch_dtype: Data type for model weights.
77
+ cache_dir: Base directory for model weight caches.
78
+ """
79
+ super().__init__()
80
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
81
+ self.hidden_size = hidden_size
82
+ self.num_visual_tokens = num_visual_tokens
83
+ self.torch_dtype = torch_dtype
84
+
85
+ cache_base = Path(cache_dir) if cache_dir else PROJECT_ROOT / "checkpoints"
86
+
87
+ print("=" * 60)
88
+ print(" MINDI 1.5 Vision-Coder β€” Initializing")
89
+ print("=" * 60)
90
+
91
+ # 1. Load MINDI custom tokenizer (NOT the base Qwen tokenizer)
92
+ tok_path = Path(tokenizer_path) if tokenizer_path else DEFAULT_TOKENIZER_PATH
93
+ print(f"\n[MINDI15] Loading MINDI tokenizer from {tok_path} ...")
94
+ self.tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(
95
+ str(tok_path),
96
+ trust_remote_code=True,
97
+ )
98
+ print(f" Vocab size: {len(self.tokenizer)}")
99
+
100
+ # 2. LLM backbone with LoRA
101
+ self.architecture = MINDIArchitecture(
102
+ model_name=model_name,
103
+ device=self.device,
104
+ cache_dir=cache_base / "base",
105
+ torch_dtype=torch_dtype,
106
+ )
107
+
108
+ # Resize embeddings to match MINDI tokenizer (includes 22 special tokens)
109
+ self.architecture.resize_embeddings(len(self.tokenizer))
110
+
111
+ # Apply LoRA
112
+ self.architecture.apply_lora()
113
+
114
+ # 3. Vision encoder (frozen CLIP + trainable projection)
115
+ self.vision_encoder = VisionEncoder(
116
+ model_name=clip_model,
117
+ llm_hidden_size=hidden_size,
118
+ device=self.device,
119
+ cache_dir=cache_base / "vision",
120
+ )
121
+
122
+ # 4. Fusion layer
123
+ self.fusion = VisionLanguageFusion(
124
+ hidden_size=hidden_size,
125
+ num_visual_tokens=num_visual_tokens,
126
+ )
127
+ self.fusion.to(self.device)
128
+
129
+ # Cache special token IDs
130
+ self._special_ids: dict[str, int] = {}
131
+ for section, (start_tok, end_tok) in MINDI_SECTION_TOKENS.items():
132
+ sid = self.tokenizer.convert_tokens_to_ids(start_tok)
133
+ eid = self.tokenizer.convert_tokens_to_ids(end_tok)
134
+ self._special_ids[f"{section}_start"] = sid
135
+ self._special_ids[f"{section}_end"] = eid
136
+
137
+ self._print_summary()
138
+
139
+ def _print_summary(self) -> None:
140
+ """Print initialization summary."""
141
+ llm_info = self.architecture.get_trainable_params()
142
+ vis_info = {
143
+ "trainable": sum(p.numel() for p in self.vision_encoder.parameters() if p.requires_grad),
144
+ "total": sum(p.numel() for p in self.vision_encoder.parameters()),
145
+ }
146
+ fus_info = self.fusion.get_trainable_params()
147
+
148
+ total_trainable = llm_info["trainable"] + vis_info["trainable"] + fus_info["trainable"]
149
+ total_all = llm_info["total"] + vis_info["total"] + fus_info["total"]
150
+
151
+ print()
152
+ print("=" * 60)
153
+ print(" MINDI 1.5 β€” Initialization Complete")
154
+ print("=" * 60)
155
+ print(f" LLM trainable (LoRA): {llm_info['trainable']:>14,}")
156
+ print(f" Vision trainable: {vis_info['trainable']:>14,}")
157
+ print(f" Fusion trainable: {fus_info['trainable']:>14,}")
158
+ print(f" ─────────────────────────────────────")
159
+ print(f" Total trainable: {total_trainable:>14,}")
160
+ print(f" Total params: {total_all:>14,}")
161
+ print(f" Tokenizer vocab: {len(self.tokenizer):>14,}")
162
+ print("=" * 60)
163
+ print()
164
+
165
+ # ── Forward pass ──────────────────────────────────────────────
166
+
167
+ def forward(
168
+ self,
169
+ input_ids: torch.Tensor,
170
+ attention_mask: Optional[torch.Tensor] = None,
171
+ labels: Optional[torch.Tensor] = None,
172
+ image: Optional[Image.Image] = None,
173
+ ) -> dict:
174
+ """
175
+ Forward pass with optional vision input.
176
+
177
+ Args:
178
+ input_ids: Token IDs (batch, seq_len).
179
+ attention_mask: Attention mask (batch, seq_len).
180
+ labels: Target token IDs for loss computation (batch, seq_len).
181
+ image: Optional PIL image for multimodal input.
182
+
183
+ Returns:
184
+ Dict with 'loss', 'logits', and optionally 'visual_tokens'.
185
+ """
186
+ model = self.architecture.get_model()
187
+
188
+ # Get text embeddings from the LLM's embedding layer
189
+ text_embeds = model.get_input_embeddings()(input_ids)
190
+
191
+ # Encode vision if image provided
192
+ visual_tokens = None
193
+ if image is not None:
194
+ visual_tokens = self.vision_encoder.encode_image(image)
195
+
196
+ # Fuse vision + text
197
+ fused_embeds, fused_mask = self.fusion(text_embeds, visual_tokens, attention_mask)
198
+
199
+ # Extend labels if vision tokens were prepended
200
+ if visual_tokens is not None and labels is not None:
201
+ batch_size = labels.shape[0]
202
+ # -100 = ignore index for cross-entropy on visual positions
203
+ visual_labels = torch.full(
204
+ (batch_size, self.num_visual_tokens),
205
+ fill_value=-100,
206
+ dtype=labels.dtype,
207
+ device=labels.device,
208
+ )
209
+ labels = torch.cat([visual_labels, labels], dim=1)
210
+
211
+ # Forward through LLM with embeddings (bypass tokenization)
212
+ outputs = model(
213
+ inputs_embeds=fused_embeds,
214
+ attention_mask=fused_mask,
215
+ labels=labels,
216
+ )
217
+
218
+ result = {
219
+ "loss": outputs.loss,
220
+ "logits": outputs.logits,
221
+ }
222
+ if visual_tokens is not None:
223
+ result["visual_tokens"] = visual_tokens
224
+
225
+ return result
226
+
227
+ # ── Generation ────────────────────────────────────────────────
228
+
229
+ @torch.no_grad()
230
+ def generate(
231
+ self,
232
+ prompt: str,
233
+ image: Optional[Image.Image] = None,
234
+ max_new_tokens: int = 2048,
235
+ temperature: float = 0.7,
236
+ top_p: float = 0.9,
237
+ top_k: int = 50,
238
+ do_sample: bool = True,
239
+ repetition_penalty: float = 1.1,
240
+ ) -> str:
241
+ """
242
+ Generate text from a prompt, optionally conditioned on an image.
243
+
244
+ Uses the MINDI custom tokenizer (with special tokens) for both
245
+ encoding the prompt and decoding the output.
246
+
247
+ Args:
248
+ prompt: Input text prompt.
249
+ image: Optional PIL image for multimodal generation.
250
+ max_new_tokens: Maximum tokens to generate.
251
+ temperature: Sampling temperature.
252
+ top_p: Nucleus sampling threshold.
253
+ top_k: Top-k sampling threshold.
254
+ do_sample: Whether to sample (False = greedy).
255
+ repetition_penalty: Penalty for repeated tokens.
256
+
257
+ Returns:
258
+ Generated text string (decoded with MINDI tokenizer).
259
+ """
260
+ model = self.architecture.get_model()
261
+ model.eval()
262
+
263
+ # Tokenize with MINDI tokenizer
264
+ inputs = self.tokenizer(prompt, return_tensors="pt")
265
+ input_ids = inputs["input_ids"].to(self.device)
266
+ attention_mask = inputs["attention_mask"].to(self.device)
267
+
268
+ # If image provided, build fused embeddings
269
+ if image is not None:
270
+ text_embeds = model.get_input_embeddings()(input_ids)
271
+ visual_tokens = self.vision_encoder.encode_image(image)
272
+ fused_embeds, fused_mask = self.fusion(text_embeds, visual_tokens, attention_mask)
273
+
274
+ output_ids = model.generate(
275
+ inputs_embeds=fused_embeds,
276
+ attention_mask=fused_mask,
277
+ max_new_tokens=max_new_tokens,
278
+ temperature=temperature,
279
+ top_p=top_p,
280
+ top_k=top_k,
281
+ do_sample=do_sample,
282
+ repetition_penalty=repetition_penalty,
283
+ pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
284
+ )
285
+ else:
286
+ # Text-only generation (direct input_ids)
287
+ output_ids = model.generate(
288
+ input_ids=input_ids,
289
+ attention_mask=attention_mask,
290
+ max_new_tokens=max_new_tokens,
291
+ temperature=temperature,
292
+ top_p=top_p,
293
+ top_k=top_k,
294
+ do_sample=do_sample,
295
+ repetition_penalty=repetition_penalty,
296
+ pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
297
+ )
298
+
299
+ # Decode only the newly generated tokens
300
+ generated_ids = output_ids[:, input_ids.shape[1]:]
301
+ text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=False)
302
+ return text.strip()
303
+
304
+ # ── Output parsing ────────────────────────────────────────────
305
+
306
+ @staticmethod
307
+ def parse_output(text: str) -> dict[str, list[str]]:
308
+ """
309
+ Parse generated text and extract ALL MINDI special-token sections.
310
+
311
+ Extracts content between each pair of special tokens:
312
+ <|think_start|> ... <|think_end|> β†’ "thinking"
313
+ <|file_start|> ... <|file_end|> β†’ "file"
314
+ <|code_start|> ... <|code_end|> β†’ "code"
315
+ <|critique_start|> ... <|critique_end|> β†’ "critique"
316
+ <|suggest_start|> ... <|suggest_end|> β†’ "suggest"
317
+ <|search_start|> ... <|search_end|> β†’ "search"
318
+ <|error_start|> ... <|error_end|> β†’ "error"
319
+ <|fix_start|> ... <|fix_end|> β†’ "fix"
320
+
321
+ Each section may appear multiple times; all occurrences are captured.
322
+
323
+ Args:
324
+ text: Raw generated text potentially containing special tokens.
325
+
326
+ Returns:
327
+ Dict mapping section name β†’ list of extracted content strings.
328
+ Empty list if section not found. Also includes "raw" with full text.
329
+ """
330
+ result: dict[str, list[str]] = {"raw": [text]}
331
+
332
+ for section, (start_tok, end_tok) in MINDI_SECTION_TOKENS.items():
333
+ # Escape the pipe characters for regex
334
+ pattern = re.escape(start_tok) + r"(.*?)" + re.escape(end_tok)
335
+ matches = re.findall(pattern, text, flags=re.DOTALL)
336
+ result[section] = [m.strip() for m in matches]
337
+
338
+ return result
339
+
340
+ # ── Phase control (for 3-phase training) ──────────────────────
341
+
342
+ def set_trainable_components(
343
+ self,
344
+ lora: bool = False,
345
+ vision_projection: bool = False,
346
+ fusion: bool = False,
347
+ ) -> dict[str, int]:
348
+ """
349
+ Enable/disable training for specific components.
350
+
351
+ Used by the trainer to implement 3-phase training:
352
+ Phase 1: lora=True, vision_projection=False, fusion=False
353
+ Phase 2: lora=False, vision_projection=True, fusion=True
354
+ Phase 3: lora=True, vision_projection=True, fusion=True
355
+
356
+ Args:
357
+ lora: Whether LoRA adapter parameters should be trainable.
358
+ vision_projection: Whether the vision projection layer should train.
359
+ fusion: Whether the fusion layer should be trainable.
360
+
361
+ Returns:
362
+ Dict with trainable param counts per component.
363
+ """
364
+ counts = {}
365
+
366
+ # LoRA parameters
367
+ peft_model = self.architecture.peft_model
368
+ if peft_model is not None:
369
+ for name, param in peft_model.named_parameters():
370
+ if "lora_" in name:
371
+ param.requires_grad = lora
372
+ counts["lora"] = sum(
373
+ p.numel() for n, p in (peft_model or self.architecture.model).named_parameters()
374
+ if "lora_" in n and p.requires_grad
375
+ )
376
+
377
+ # Vision projection
378
+ for param in self.vision_encoder.projection.parameters():
379
+ param.requires_grad = vision_projection
380
+ counts["vision_projection"] = sum(
381
+ p.numel() for p in self.vision_encoder.projection.parameters() if p.requires_grad
382
+ )
383
+
384
+ # Fusion layer
385
+ for param in self.fusion.parameters():
386
+ param.requires_grad = fusion
387
+ counts["fusion"] = sum(
388
+ p.numel() for p in self.fusion.parameters() if p.requires_grad
389
+ )
390
+
391
+ counts["total_trainable"] = counts["lora"] + counts["vision_projection"] + counts["fusion"]
392
+
393
+ print(f"[MINDI15] Trainable: LoRA={counts['lora']:,} | "
394
+ f"VisionProj={counts['vision_projection']:,} | "
395
+ f"Fusion={counts['fusion']:,} | "
396
+ f"Total={counts['total_trainable']:,}")
397
+
398
+ return counts
399
+
400
+ # ── Save / Load ───────────────────────────────────────────────
401
+
402
+ def save(self, save_dir: Optional[Path] = None) -> Path:
403
+ """
404
+ Save all trainable weights (LoRA + vision projection + fusion).
405
+
406
+ Args:
407
+ save_dir: Root directory for saving. Defaults to checkpoints/mindi15.
408
+
409
+ Returns:
410
+ Path to save directory.
411
+ """
412
+ save_path = Path(save_dir) if save_dir else PROJECT_ROOT / "checkpoints" / "mindi15"
413
+ save_path.mkdir(parents=True, exist_ok=True)
414
+
415
+ # LoRA adapter
416
+ self.architecture.save_lora(save_path / "lora")
417
+
418
+ # Vision projection
419
+ self.vision_encoder.save_projection(save_path / "vision")
420
+
421
+ # Fusion layer
422
+ fusion_path = save_path / "fusion"
423
+ fusion_path.mkdir(parents=True, exist_ok=True)
424
+ torch.save(self.fusion.state_dict(), fusion_path / "fusion.pt")
425
+
426
+ print(f"[MINDI15] All weights saved to {save_path}")
427
+ return save_path
428
+
429
+ def load(self, load_dir: Path) -> None:
430
+ """
431
+ Load all trainable weights (LoRA + vision projection + fusion).
432
+
433
+ Args:
434
+ load_dir: Root directory containing saved weights.
435
+ """
436
+ load_path = Path(load_dir)
437
+ if not load_path.exists():
438
+ raise FileNotFoundError(f"Checkpoint not found: {load_path}")
439
+
440
+ # LoRA adapter
441
+ lora_path = load_path / "lora"
442
+ if lora_path.exists():
443
+ self.architecture.load_lora(lora_path)
444
+
445
+ # Vision projection
446
+ vision_path = load_path / "vision"
447
+ if vision_path.exists():
448
+ self.vision_encoder.load_projection(vision_path)
449
+
450
+ # Fusion layer
451
+ fusion_file = load_path / "fusion" / "fusion.pt"
452
+ if fusion_file.exists():
453
+ state_dict = torch.load(fusion_file, map_location=self.device, weights_only=True)
454
+ self.fusion.load_state_dict(state_dict)
455
+ print(f"[MINDI15] Fusion loaded from {fusion_file.parent}")
456
+
457
+ print(f"[MINDI15] All weights loaded from {load_path}")
458
+
459
+ # ── Utilities ─────────────────────────────────────────────────
460
+
461
+ def get_all_trainable_params(self) -> dict:
462
+ """Get combined trainable parameter counts across all components."""
463
+ llm = self.architecture.get_trainable_params()
464
+ vis_trainable = sum(
465
+ p.numel() for p in self.vision_encoder.parameters() if p.requires_grad
466
+ )
467
+ fus = self.fusion.get_trainable_params()
468
+
469
+ total_trainable = llm["trainable"] + vis_trainable + fus["trainable"]
470
+ total_all = llm["total"] + sum(p.numel() for p in self.vision_encoder.parameters()) + fus["total"]
471
+
472
+ return {
473
+ "llm_trainable": llm["trainable"],
474
+ "llm_total": llm["total"],
475
+ "vision_trainable": vis_trainable,
476
+ "fusion_trainable": fus["trainable"],
477
+ "total_trainable": total_trainable,
478
+ "total_params": total_all,
479
+ "trainable_pct": round(100.0 * total_trainable / total_all, 4) if total_all > 0 else 0.0,
480
+ }
481
+
482
+ def print_info(self) -> None:
483
+ """Print complete model information."""
484
+ self.architecture.print_model_info()
485
+ info = self.get_all_trainable_params()
486
+ print(" MINDI 1.5 Combined Trainable Parameters:")
487
+ print(f" LLM (LoRA): {info['llm_trainable']:>14,}")
488
+ print(f" Vision proj: {info['vision_trainable']:>14,}")
489
+ print(f" Fusion: {info['fusion_trainable']:>14,}")
490
+ print(f" Total trainable: {info['total_trainable']:>14,}")
491
+ print(f" Total params: {info['total_params']:>14,}")
492
+ print(f" Trainable %: {info['trainable_pct']:>13.2f}%")
493
+ print()
494
+
495
+
496
+ # ── Test block ────────────────────────────────────────────────────────
497
+ if __name__ == "__main__":
498
+ print("=" * 60)
499
+ print(" MINDI 1.5 β€” Complete Model Test")
500
+ print("=" * 60)
501
+ print()
502
+
503
+ # ── Test 1: parse_output (no GPU needed) ─────────────────────
504
+ print(" Test 1: parse_output()")
505
+ sample_output = (
506
+ "<|think_start|>The user wants a Python function.<|think_end|>"
507
+ "<|file_start|>main.py<|file_end|>"
508
+ "<|code_start|>def hello():\n print('Hello MINDI!')<|code_end|>"
509
+ "<|critique_start|>Missing type hints and docstring.<|critique_end|>"
510
+ "<|suggest_start|>Add return type annotation.<|suggest_end|>"
511
+ "<|search_start|>python type hints best practices<|search_end|>"
512
+ "<|error_start|>NameError: name 'x' is not defined<|error_end|>"
513
+ "<|fix_start|>Add x = 0 before the loop.<|fix_end|>"
514
+ "<|think_start|>Let me also add error handling.<|think_end|>"
515
+ )
516
+
517
+ parsed = MINDI15.parse_output(sample_output)
518
+
519
+ assert len(parsed["thinking"]) == 2, f"Expected 2 thinking sections, got {len(parsed['thinking'])}"
520
+ assert parsed["thinking"][0] == "The user wants a Python function."
521
+ assert parsed["thinking"][1] == "Let me also add error handling."
522
+ assert parsed["file"] == ["main.py"]
523
+ assert parsed["code"] == ["def hello():\n print('Hello MINDI!')"]
524
+ assert parsed["critique"] == ["Missing type hints and docstring."]
525
+ assert parsed["suggest"] == ["Add return type annotation."]
526
+ assert parsed["search"] == ["python type hints best practices"]
527
+ assert parsed["error"] == ["NameError: name 'x' is not defined"]
528
+ assert parsed["fix"] == ["Add x = 0 before the loop."]
529
+ assert "raw" in parsed
530
+ print(" All 8 section types extracted correctly βœ“")
531
+ print(f" Sections found: {[k for k, v in parsed.items() if k != 'raw' and v]}")
532
+
533
+ # ── Test 2: parse_output with missing sections ───────────────
534
+ print("\n Test 2: parse_output() with partial output")
535
+ partial = "<|code_start|>print('hi')<|code_end|>"
536
+ parsed2 = MINDI15.parse_output(partial)
537
+ assert parsed2["code"] == ["print('hi')"]
538
+ assert parsed2["thinking"] == []
539
+ assert parsed2["file"] == []
540
+ assert parsed2["fix"] == []
541
+ print(" Missing sections return empty lists βœ“")
542
+
543
+ # ── Test 3: parse_output with empty input ────────────────────
544
+ print("\n Test 3: parse_output() with empty string")
545
+ parsed3 = MINDI15.parse_output("")
546
+ assert all(v == [] for k, v in parsed3.items() if k != "raw")
547
+ print(" Empty input returns all empty lists βœ“")
548
+
549
+ # ── Test 4: Verify MINDI_SECTION_TOKENS covers all 8 ────────
550
+ print("\n Test 4: Token coverage")
551
+ expected_sections = {"thinking", "file", "code", "critique", "suggest", "search", "error", "fix"}
552
+ assert set(MINDI_SECTION_TOKENS.keys()) == expected_sections
553
+ print(f" All 8 sections defined: {sorted(expected_sections)} βœ“")
554
+
555
+ # ── GPU-dependent tests (skip if no CUDA) ────────────────────
556
+ if torch.cuda.is_available():
557
+ print("\n Test 5: Full model initialization (GPU)")
558
+ model = MINDI15()
559
+ model.print_info()
560
+
561
+ # Test set_trainable_components (Phase 1)
562
+ print("\n Test 6: Phase 1 β€” LoRA only")
563
+ counts = model.set_trainable_components(lora=True, vision_projection=False, fusion=False)
564
+ assert counts["lora"] > 0
565
+ assert counts["vision_projection"] == 0
566
+ assert counts["fusion"] == 0
567
+
568
+ # Test set_trainable_components (Phase 2)
569
+ print("\n Test 7: Phase 2 β€” Vision bridge only")
570
+ counts = model.set_trainable_components(lora=False, vision_projection=True, fusion=True)
571
+ assert counts["lora"] == 0
572
+ assert counts["vision_projection"] > 0
573
+ assert counts["fusion"] > 0
574
+
575
+ # Test set_trainable_components (Phase 3)
576
+ print("\n Test 8: Phase 3 β€” All trainable")
577
+ counts = model.set_trainable_components(lora=True, vision_projection=True, fusion=True)
578
+ assert counts["lora"] > 0
579
+ assert counts["vision_projection"] > 0
580
+ assert counts["fusion"] > 0
581
+
582
+ # Test forward (text only)
583
+ print("\n Test 9: Forward pass (text only)")
584
+ tokens = model.tokenizer("Hello MINDI!", return_tensors="pt")
585
+ input_ids = tokens["input_ids"].to(model.device)
586
+ attn_mask = tokens["attention_mask"].to(model.device)
587
+ result = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids)
588
+ assert result["loss"] is not None
589
+ print(f" Loss: {result['loss'].item():.4f}")
590
+ print(f" Logits: {result['logits'].shape}")
591
+
592
+ # Test forward (with image)
593
+ print("\n Test 10: Forward pass (with dummy image)")
594
+ dummy_img = Image.new("RGB", (224, 224), color=(100, 150, 200))
595
+ result_v = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, image=dummy_img)
596
+ assert result_v["loss"] is not None
597
+ assert "visual_tokens" in result_v
598
+ print(f" Loss: {result_v['loss'].item():.4f}")
599
+ print(f" Visual tokens: {result_v['visual_tokens'].shape}")
600
+
601
+ # Test generate (text only)
602
+ print("\n Test 11: Generate (text only, short)")
603
+ output = model.generate("Write a hello world in Python:", max_new_tokens=50)
604
+ print(f" Output: {output[:100]}...")
605
+
606
+ print("\n Test 12: Save/load round-trip")
607
+ import tempfile
608
+ with tempfile.TemporaryDirectory() as tmp:
609
+ model.save(Path(tmp))
610
+ # Verify files exist
611
+ assert (Path(tmp) / "lora").exists()
612
+ assert (Path(tmp) / "vision" / "projection.pt").exists()
613
+ assert (Path(tmp) / "fusion" / "fusion.pt").exists()
614
+ print(" Save βœ“")
615
+ else:
616
+ print("\n [SKIP] GPU tests (no CUDA available)")
617
+ print(" Tests 5-12 require GPU with ~20GB VRAM")
618
+
619
+ print("\n βœ“ All MINDI 1.5 model tests passed!")
620
+ print("=" * 60)
src/model/vision_encoder.py CHANGED
@@ -1,8 +1,9 @@
1
  """
2
  MINDI 1.5 Vision-Coder β€” Vision Encoder
3
 
4
- Uses CLIP ViT-L/14 to encode UI screenshots into embeddings
5
- that the coding model can understand and critique.
 
6
  """
7
 
8
  from __future__ import annotations
@@ -13,79 +14,237 @@ from typing import Optional
13
  import torch
14
  import torch.nn as nn
15
  from PIL import Image
16
- from transformers import CLIPModel, CLIPProcessor
17
 
18
 
19
  class VisionEncoder(nn.Module):
20
- """CLIP-based vision encoder for UI screenshot understanding."""
 
 
 
 
 
 
 
 
21
 
22
  def __init__(
23
  self,
24
  model_name: str = "openai/clip-vit-large-patch14",
25
- projection_dim: int = 768,
26
  device: Optional[str] = None,
27
  cache_dir: Optional[Path] = None,
 
28
  ) -> None:
 
 
 
 
 
 
 
 
 
 
29
  super().__init__()
 
 
30
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
31
- self.cache_dir = cache_dir or Path("./checkpoints/vision")
32
  self.cache_dir.mkdir(parents=True, exist_ok=True)
33
 
34
- # Load CLIP model and processor
35
- self.clip: CLIPModel = CLIPModel.from_pretrained(
36
- model_name, cache_dir=str(self.cache_dir)
 
 
 
37
  )
38
- self.processor: CLIPProcessor = CLIPProcessor.from_pretrained(
39
- model_name, cache_dir=str(self.cache_dir)
 
40
  )
41
 
42
- # Freeze CLIP backbone β€” we only train the projection layer
43
  for param in self.clip.parameters():
44
  param.requires_grad = False
 
45
 
46
- # Trainable projection: CLIP hidden β†’ LLM embedding space
47
- clip_hidden_size: int = self.clip.config.vision_config.hidden_size # 1024
48
- self.projection = nn.Sequential(
49
- nn.Linear(clip_hidden_size, projection_dim),
50
- nn.GELU(),
51
- nn.Linear(projection_dim, projection_dim),
52
- )
53
 
54
  self.to(self.device)
55
 
56
- def encode_image(self, image: Image.Image) -> torch.Tensor:
57
- """Encode a PIL image into a projected embedding tensor."""
58
- inputs = self.processor(images=image, return_tensors="pt")
59
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  with torch.no_grad():
62
- vision_outputs = self.clip.vision_model(**inputs)
63
- # Use [CLS] token embedding
64
- cls_embedding = vision_outputs.last_hidden_state[:, 0, :]
65
 
66
- # Project into LLM embedding space (this part IS trainable)
67
- projected = self.projection(cls_embedding)
68
  return projected
69
 
70
- def encode_screenshot(self, screenshot_path: Path) -> torch.Tensor:
71
- """Load a screenshot from disk and encode it."""
72
- if not screenshot_path.exists():
73
- raise FileNotFoundError(f"Screenshot not found: {screenshot_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- image = Image.open(screenshot_path).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  return self.encode_image(image)
77
 
78
  def save_projection(self, save_dir: Optional[Path] = None) -> Path:
79
- """Save only the trainable projection weights."""
80
- save_path = save_dir or self.cache_dir / "projection"
 
 
 
 
 
 
 
 
81
  save_path.mkdir(parents=True, exist_ok=True)
82
  torch.save(self.projection.state_dict(), save_path / "projection.pt")
 
83
  return save_path
84
 
85
  def load_projection(self, load_dir: Path) -> None:
86
- """Load projection weights from disk."""
87
- weights_path = load_dir / "projection.pt"
 
 
 
 
 
88
  if not weights_path.exists():
89
  raise FileNotFoundError(f"Projection weights not found: {weights_path}")
90
  state_dict = torch.load(weights_path, map_location=self.device, weights_only=True)
91
  self.projection.load_state_dict(state_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  MINDI 1.5 Vision-Coder β€” Vision Encoder
3
 
4
+ Uses CLIP ViT-L/14 (frozen) to encode UI screenshots into 256 visual
5
+ tokens projected from 1024 β†’ 4096 to match the Qwen hidden dimension.
6
+ Output shape: (batch, 256, 4096).
7
  """
8
 
9
  from __future__ import annotations
 
14
  import torch
15
  import torch.nn as nn
16
  from PIL import Image
17
+ from transformers import CLIPImageProcessor, CLIPVisionModel
18
 
19
 
20
  class VisionEncoder(nn.Module):
21
+ """
22
+ CLIP ViT-L/14 vision encoder for MINDI 1.5.
23
+
24
+ Extracts ALL 256 patch tokens (excludes CLS) from CLIP and
25
+ projects them from 1024 β†’ 4096 to match Qwen2.5 hidden_size.
26
+ The CLIP backbone is frozen; only the projection layer trains.
27
+ """
28
+
29
+ NUM_PATCHES: int = 256 # ViT-L/14: 16Γ—16 patches from 224Γ—224
30
 
31
  def __init__(
32
  self,
33
  model_name: str = "openai/clip-vit-large-patch14",
34
+ llm_hidden_size: int = 4096,
35
  device: Optional[str] = None,
36
  cache_dir: Optional[Path] = None,
37
+ torch_dtype: torch.dtype = torch.float32,
38
  ) -> None:
39
+ """
40
+ Initialize the vision encoder.
41
+
42
+ Args:
43
+ model_name: HuggingFace CLIP vision model identifier.
44
+ llm_hidden_size: Target projection dimension (must match LLM hidden_size).
45
+ device: Target device ('cuda', 'cpu', or None for auto).
46
+ cache_dir: Local directory for model weight cache.
47
+ torch_dtype: Data type for CLIP weights (projection always float32).
48
+ """
49
  super().__init__()
50
+ self.model_name = model_name
51
+ self.llm_hidden_size = llm_hidden_size
52
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
53
+ self.cache_dir = Path(cache_dir) if cache_dir else Path("./checkpoints/vision")
54
  self.cache_dir.mkdir(parents=True, exist_ok=True)
55
 
56
+ # Load CLIP vision model (no text tower) and image processor
57
+ print(f"[VisionEncoder] Loading {model_name} ...")
58
+ self.clip = CLIPVisionModel.from_pretrained(
59
+ model_name,
60
+ cache_dir=str(self.cache_dir),
61
+ torch_dtype=torch_dtype,
62
  )
63
+ self.image_processor = CLIPImageProcessor.from_pretrained(
64
+ model_name,
65
+ cache_dir=str(self.cache_dir),
66
  )
67
 
68
+ # Freeze entire CLIP backbone
69
  for param in self.clip.parameters():
70
  param.requires_grad = False
71
+ self.clip.eval()
72
 
73
+ # Trainable projection: CLIP hidden (1024) β†’ LLM hidden (4096)
74
+ clip_hidden_size: int = self.clip.config.hidden_size # 1024
75
+ self.projection = nn.Linear(clip_hidden_size, self.llm_hidden_size)
 
 
 
 
76
 
77
  self.to(self.device)
78
 
79
+ trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
80
+ total = sum(p.numel() for p in self.parameters())
81
+ print(f"[VisionEncoder] Loaded β€” {clip_hidden_size} β†’ {self.llm_hidden_size}")
82
+ print(f" Trainable: {trainable:,} | Total: {total:,}")
83
+
84
+ def encode_image(self, image: Optional[Image.Image]) -> Optional[torch.Tensor]:
85
+ """
86
+ Encode a single PIL image into projected patch token embeddings.
87
+
88
+ Args:
89
+ image: A PIL Image (RGB), or None.
90
+
91
+ Returns:
92
+ Tensor of shape (1, 256, 4096) or None if input is None.
93
+ """
94
+ if image is None:
95
+ return None
96
+
97
+ inputs = self.image_processor(images=image, return_tensors="pt")
98
+ pixel_values = inputs["pixel_values"].to(device=self.device, dtype=self.clip.dtype)
99
 
100
  with torch.no_grad():
101
+ vision_outputs = self.clip(pixel_values=pixel_values)
102
+ # last_hidden_state: (batch, 257, 1024) β€” 1 CLS + 256 patches
103
+ patch_tokens = vision_outputs.last_hidden_state[:, 1:, :] # (1, 256, 1024)
104
 
105
+ # Project into LLM embedding space (trainable)
106
+ projected = self.projection(patch_tokens.float()) # (1, 256, 4096)
107
  return projected
108
 
109
+ def encode_batch(self, images: list[Optional[Image.Image]]) -> list[Optional[torch.Tensor]]:
110
+ """
111
+ Encode a batch of images. None entries pass through as None.
112
+
113
+ Args:
114
+ images: List of PIL Images or Nones.
115
+
116
+ Returns:
117
+ List of tensors (1, 256, 4096) or Nones matching input order.
118
+ """
119
+ results: list[Optional[torch.Tensor]] = [None] * len(images)
120
+ valid_indices = [i for i, img in enumerate(images) if img is not None]
121
+
122
+ if not valid_indices:
123
+ return results
124
+
125
+ valid_images = [images[i] for i in valid_indices]
126
+ inputs = self.image_processor(images=valid_images, return_tensors="pt")
127
+ pixel_values = inputs["pixel_values"].to(device=self.device, dtype=self.clip.dtype)
128
 
129
+ with torch.no_grad():
130
+ vision_outputs = self.clip(pixel_values=pixel_values)
131
+ patch_tokens = vision_outputs.last_hidden_state[:, 1:, :] # (N, 256, 1024)
132
+
133
+ projected = self.projection(patch_tokens.float()) # (N, 256, 4096)
134
+
135
+ for batch_idx, orig_idx in enumerate(valid_indices):
136
+ results[orig_idx] = projected[batch_idx].unsqueeze(0) # (1, 256, 4096)
137
+
138
+ return results
139
+
140
+ def encode_screenshot(self, screenshot_path: Path) -> Optional[torch.Tensor]:
141
+ """
142
+ Load a screenshot from disk and encode it.
143
+
144
+ Args:
145
+ screenshot_path: Path to image file.
146
+
147
+ Returns:
148
+ Tensor of shape (1, 256, 4096).
149
+ """
150
+ path = Path(screenshot_path)
151
+ if not path.exists():
152
+ raise FileNotFoundError(f"Screenshot not found: {path}")
153
+ image = Image.open(path).convert("RGB")
154
  return self.encode_image(image)
155
 
156
  def save_projection(self, save_dir: Optional[Path] = None) -> Path:
157
+ """
158
+ Save only the trainable projection weights.
159
+
160
+ Args:
161
+ save_dir: Directory to save to. Defaults to cache_dir/projection.
162
+
163
+ Returns:
164
+ Path where weights were saved.
165
+ """
166
+ save_path = Path(save_dir) if save_dir else self.cache_dir / "projection"
167
  save_path.mkdir(parents=True, exist_ok=True)
168
  torch.save(self.projection.state_dict(), save_path / "projection.pt")
169
+ print(f"[VisionEncoder] Projection saved to {save_path}")
170
  return save_path
171
 
172
  def load_projection(self, load_dir: Path) -> None:
173
+ """
174
+ Load projection weights from disk.
175
+
176
+ Args:
177
+ load_dir: Directory containing projection.pt.
178
+ """
179
+ weights_path = Path(load_dir) / "projection.pt"
180
  if not weights_path.exists():
181
  raise FileNotFoundError(f"Projection weights not found: {weights_path}")
182
  state_dict = torch.load(weights_path, map_location=self.device, weights_only=True)
183
  self.projection.load_state_dict(state_dict)
184
+ print(f"[VisionEncoder] Projection loaded from {load_dir}")
185
+
186
+ def get_num_visual_tokens(self) -> int:
187
+ """Return the number of visual tokens produced per image (256)."""
188
+ return self.NUM_PATCHES
189
+
190
+
191
+ # ── Test block ────────────────────────────────────────────────────────
192
+ if __name__ == "__main__":
193
+ print("=" * 60)
194
+ print(" MINDI 1.5 β€” Vision Encoder Test")
195
+ print("=" * 60)
196
+ print()
197
+
198
+ # 1. Initialize encoder
199
+ encoder = VisionEncoder(
200
+ model_name="openai/clip-vit-large-patch14",
201
+ llm_hidden_size=4096,
202
+ )
203
+
204
+ # 2. Create a dummy image (224Γ—224 RGB)
205
+ dummy_image = Image.new("RGB", (224, 224), color=(128, 128, 128))
206
+
207
+ # 3. Encode single image
208
+ print("\n Encoding single image ...")
209
+ output = encoder.encode_image(dummy_image)
210
+ assert output is not None
211
+ print(f" Output shape: {output.shape}")
212
+ assert output.shape == (1, 256, 4096), f"Expected (1, 256, 4096), got {output.shape}"
213
+
214
+ # 4. Encode None β†’ should return None
215
+ none_output = encoder.encode_image(None)
216
+ assert none_output is None, "Expected None for None input"
217
+ print(" None input β†’ None output βœ“")
218
+
219
+ # 5. Encode batch (mixed with None)
220
+ print("\n Encoding batch [image, None, image] ...")
221
+ batch_results = encoder.encode_batch([dummy_image, None, dummy_image])
222
+ assert batch_results[0] is not None and batch_results[0].shape == (1, 256, 4096)
223
+ assert batch_results[1] is None
224
+ assert batch_results[2] is not None and batch_results[2].shape == (1, 256, 4096)
225
+ print(f" Batch results: [{batch_results[0].shape}, None, {batch_results[2].shape}]")
226
+
227
+ # 6. Check trainable params (only projection should train)
228
+ trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
229
+ frozen = sum(p.numel() for p in encoder.parameters() if not p.requires_grad)
230
+ print(f"\n Trainable: {trainable:,}")
231
+ print(f" Frozen: {frozen:,}")
232
+ assert trainable == 1024 * 4096 + 4096, f"Unexpected trainable count: {trainable}"
233
+ assert frozen > trainable, "CLIP backbone should be frozen"
234
+
235
+ # 7. Save and reload projection
236
+ print("\n Testing save/load projection ...")
237
+ import tempfile
238
+ with tempfile.TemporaryDirectory() as tmp:
239
+ save_path = encoder.save_projection(Path(tmp))
240
+ old_weight = encoder.projection.weight.clone()
241
+ # Perturb weights
242
+ encoder.projection.weight.data.fill_(0.0)
243
+ assert not torch.equal(encoder.projection.weight, old_weight)
244
+ # Reload
245
+ encoder.load_projection(Path(tmp))
246
+ assert torch.equal(encoder.projection.weight, old_weight), "Weights not restored!"
247
+ print(" Save/load round-trip βœ“")
248
+
249
+ print("\n βœ“ All vision encoder tests passed!")
250
+ print("=" * 60)
src/training/mindi_trainer.py ADDED
@@ -0,0 +1,745 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MINDI 1.5 Vision-Coder β€” Trainer
3
+
4
+ Production-ready 3-phase training loop optimized for AMD MI300X (192GB VRAM).
5
+ Streams training data from disk (4.18GB train.jsonl) to avoid RAM exhaustion.
6
+
7
+ Phases:
8
+ Phase 1 (steps 0–5000): LoRA only, LR 2e-4, batch 16
9
+ Phase 2 (steps 5000–7500): Vision bridge only, LR 1e-5, batch 8
10
+ Phase 3 (steps 7500–10000): All trainable, LR 5e-5, batch 12
11
+
12
+ MI300X specifics:
13
+ - ROCm presents as CUDA to PyTorch (torch.cuda.* works)
14
+ - bf16 (NOT fp16) for AMD stability
15
+ - torch.compile() optional (works on ROCm)
16
+ - Gradient checkpointing enabled
17
+ - DataLoader: num_workers=4, pin_memory=True, prefetch_factor=2
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ import math
24
+ import time
25
+ from dataclasses import dataclass, field
26
+ from pathlib import Path
27
+ from typing import Any, Iterator, Optional
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ from torch.optim import AdamW
32
+ from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
33
+ from torch.utils.data import DataLoader, IterableDataset
34
+
35
+ # ── Configuration ─────────────────────────────────────────────────────
36
+
37
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
38
+
39
+
40
+ @dataclass
41
+ class PhaseConfig:
42
+ """Configuration for a single training phase."""
43
+ name: str
44
+ start_step: int
45
+ end_step: int
46
+ learning_rate: float
47
+ batch_size: int
48
+ gradient_accumulation_steps: int = 4
49
+ # Component toggles
50
+ lora: bool = False
51
+ vision_projection: bool = False
52
+ fusion: bool = False
53
+
54
+
55
+ @dataclass
56
+ class TrainingConfig:
57
+ """Full training configuration."""
58
+
59
+ # Data paths
60
+ train_file: Path = field(default_factory=lambda: PROJECT_ROOT / "data" / "processed" / "train.jsonl")
61
+ val_file: Path = field(default_factory=lambda: PROJECT_ROOT / "data" / "processed" / "val.jsonl")
62
+
63
+ # Output
64
+ output_dir: Path = field(default_factory=lambda: PROJECT_ROOT / "checkpoints" / "training")
65
+ log_dir: Path = field(default_factory=lambda: PROJECT_ROOT / "logs" / "training")
66
+
67
+ # Model
68
+ max_seq_length: int = 8192
69
+ use_compile: bool = False
70
+ gradient_checkpointing: bool = True
71
+
72
+ # Hardware (MI300X defaults)
73
+ dtype: str = "bf16"
74
+ num_workers: int = 4
75
+ pin_memory: bool = True
76
+ prefetch_factor: int = 2
77
+
78
+ # Training
79
+ weight_decay: float = 0.01
80
+ warmup_ratio: float = 0.03
81
+ max_grad_norm: float = 1.0
82
+ seed: int = 42
83
+
84
+ # Logging
85
+ log_every_n_steps: int = 10
86
+ eval_every_n_steps: int = 250
87
+ save_every_n_steps: int = 500
88
+
89
+ # Phases
90
+ phases: list[PhaseConfig] = field(default_factory=lambda: [
91
+ PhaseConfig(
92
+ name="phase1_lora",
93
+ start_step=0, end_step=5000,
94
+ learning_rate=2e-4, batch_size=16,
95
+ lora=True, vision_projection=False, fusion=False,
96
+ ),
97
+ PhaseConfig(
98
+ name="phase2_vision_bridge",
99
+ start_step=5000, end_step=7500,
100
+ learning_rate=1e-5, batch_size=8,
101
+ lora=False, vision_projection=True, fusion=True,
102
+ ),
103
+ PhaseConfig(
104
+ name="phase3_all",
105
+ start_step=7500, end_step=10000,
106
+ learning_rate=5e-5, batch_size=12,
107
+ lora=True, vision_projection=True, fusion=True,
108
+ ),
109
+ ])
110
+
111
+ @property
112
+ def total_steps(self) -> int:
113
+ return self.phases[-1].end_step if self.phases else 0
114
+
115
+ @property
116
+ def torch_dtype(self) -> torch.dtype:
117
+ return {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.dtype]
118
+
119
+
120
+ # ── Streaming Dataset ─────────────────────────────────────────────────
121
+
122
+ class StreamingJSONLDataset(IterableDataset):
123
+ """
124
+ Streams JSONL training data from disk line by line.
125
+ Tokenizes on-the-fly to avoid loading 4+ GB into RAM.
126
+
127
+ Expected JSONL format:
128
+ {"id": "...", "type": "...", "source": "...",
129
+ "messages": [{"role": "system", "content": "..."},
130
+ {"role": "user", "content": "..."},
131
+ {"role": "assistant", "content": "..."}],
132
+ "metadata": {...}}
133
+ """
134
+
135
+ def __init__(
136
+ self,
137
+ file_path: Path,
138
+ tokenizer: Any,
139
+ max_length: int = 8192,
140
+ shuffle_buffer: int = 10000,
141
+ seed: int = 42,
142
+ ) -> None:
143
+ self.file_path = Path(file_path)
144
+ self.tokenizer = tokenizer
145
+ self.max_length = max_length
146
+ self.shuffle_buffer = shuffle_buffer
147
+ self.seed = seed
148
+
149
+ if not self.file_path.exists():
150
+ raise FileNotFoundError(f"Training data not found: {self.file_path}")
151
+
152
+ def _format_messages(self, messages: list[dict[str, str]]) -> str:
153
+ """Format chat messages into a single training string."""
154
+ # Use the tokenizer's chat template if available
155
+ if hasattr(self.tokenizer, "apply_chat_template"):
156
+ return self.tokenizer.apply_chat_template(
157
+ messages, tokenize=False, add_generation_prompt=False
158
+ )
159
+ # Fallback: simple concatenation
160
+ parts = []
161
+ for msg in messages:
162
+ role = msg.get("role", "user")
163
+ content = msg.get("content", "")
164
+ parts.append(f"<|{role}|>\n{content}")
165
+ return "\n".join(parts)
166
+
167
+ def _tokenize(self, text: str) -> Optional[dict[str, torch.Tensor]]:
168
+ """Tokenize text and create training labels."""
169
+ encoded = self.tokenizer(
170
+ text,
171
+ max_length=self.max_length,
172
+ truncation=True,
173
+ padding="max_length",
174
+ return_tensors="pt",
175
+ )
176
+ input_ids = encoded["input_ids"].squeeze(0)
177
+ attention_mask = encoded["attention_mask"].squeeze(0)
178
+
179
+ # Labels = input_ids, with padding tokens masked as -100
180
+ labels = input_ids.clone()
181
+ labels[attention_mask == 0] = -100
182
+
183
+ return {
184
+ "input_ids": input_ids,
185
+ "attention_mask": attention_mask,
186
+ "labels": labels,
187
+ }
188
+
189
+ def _line_iterator(self) -> Iterator[dict]:
190
+ """Iterate over JSONL file line by line."""
191
+ with open(self.file_path, "r", encoding="utf-8") as f:
192
+ for line in f:
193
+ line = line.strip()
194
+ if line:
195
+ yield json.loads(line)
196
+
197
+ def _shuffled_iterator(self) -> Iterator[dict]:
198
+ """Reservoir-style shuffle buffer for streaming data."""
199
+ import random
200
+ rng = random.Random(self.seed)
201
+ buffer: list[dict] = []
202
+
203
+ for item in self._line_iterator():
204
+ buffer.append(item)
205
+ if len(buffer) >= self.shuffle_buffer:
206
+ rng.shuffle(buffer)
207
+ yield from buffer
208
+ buffer.clear()
209
+
210
+ # Flush remaining items
211
+ if buffer:
212
+ rng.shuffle(buffer)
213
+ yield from buffer
214
+
215
+ def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
216
+ for example in self._shuffled_iterator():
217
+ messages = example.get("messages", [])
218
+ if not messages:
219
+ continue
220
+ text = self._format_messages(messages)
221
+ tokenized = self._tokenize(text)
222
+ if tokenized is not None:
223
+ yield tokenized
224
+
225
+ def count_lines(self) -> int:
226
+ """Count total lines (for progress estimation). Reads file once."""
227
+ count = 0
228
+ with open(self.file_path, "r", encoding="utf-8") as f:
229
+ for _ in f:
230
+ count += 1
231
+ return count
232
+
233
+
234
+ # ── Trainer ───────────────────────────────────────────────────────────
235
+
236
+ class MINDITrainer:
237
+ """
238
+ 3-phase trainer for MINDI 1.5 Vision-Coder.
239
+
240
+ Optimized for AMD MI300X 192GB:
241
+ - bf16 mixed precision
242
+ - Gradient checkpointing
243
+ - Streaming data from disk
244
+ - Optional torch.compile()
245
+ - Phase-based component freezing/unfreezing
246
+ """
247
+
248
+ def __init__(
249
+ self,
250
+ model: nn.Module,
251
+ config: TrainingConfig,
252
+ ) -> None:
253
+ """
254
+ Initialize the trainer.
255
+
256
+ Args:
257
+ model: MINDI15 model instance (already initialized).
258
+ config: Training configuration.
259
+ """
260
+ self.model = model
261
+ self.config = config
262
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
263
+ self.global_step = 0
264
+ self.best_val_loss = float("inf")
265
+
266
+ # Create output directories
267
+ self.config.output_dir.mkdir(parents=True, exist_ok=True)
268
+ self.config.log_dir.mkdir(parents=True, exist_ok=True)
269
+
270
+ # Gradient checkpointing
271
+ if config.gradient_checkpointing:
272
+ base_model = self.model.architecture.get_model()
273
+ if hasattr(base_model, "gradient_checkpointing_enable"):
274
+ base_model.gradient_checkpointing_enable()
275
+ print("[MINDITrainer] Gradient checkpointing enabled")
276
+
277
+ # Optional torch.compile (works on ROCm)
278
+ if config.use_compile:
279
+ print("[MINDITrainer] Compiling model with torch.compile() ...")
280
+ self.model.architecture.peft_model = torch.compile(
281
+ self.model.architecture.peft_model
282
+ )
283
+ print("[MINDITrainer] Compilation complete")
284
+
285
+ # Mixed precision scaler (bf16 doesn't need GradScaler, but keep structure)
286
+ self.use_amp = config.dtype in ("bf16", "fp16")
287
+ self.amp_dtype = config.torch_dtype
288
+
289
+ # Training log
290
+ self.log_file = config.log_dir / "training_log.jsonl"
291
+ self.metrics_history: list[dict] = []
292
+
293
+ print(f"[MINDITrainer] Device: {self.device}")
294
+ print(f"[MINDITrainer] Dtype: {config.dtype}")
295
+ print(f"[MINDITrainer] Total steps: {config.total_steps}")
296
+ print(f"[MINDITrainer] Phases: {len(config.phases)}")
297
+
298
+ def _build_optimizer(self, phase: PhaseConfig) -> AdamW:
299
+ """Build optimizer for the current phase (only trainable params)."""
300
+ params = [p for p in self.model.parameters() if p.requires_grad]
301
+ if not params:
302
+ raise RuntimeError(f"No trainable parameters in phase '{phase.name}'")
303
+ return AdamW(
304
+ params,
305
+ lr=phase.learning_rate,
306
+ weight_decay=self.config.weight_decay,
307
+ betas=(0.9, 0.95),
308
+ )
309
+
310
+ def _build_scheduler(
311
+ self, optimizer: AdamW, phase: PhaseConfig
312
+ ) -> torch.optim.lr_scheduler.LRScheduler:
313
+ """Build LR scheduler: linear warmup + cosine decay."""
314
+ phase_steps = phase.end_step - phase.start_step
315
+ warmup_steps = max(1, int(phase_steps * self.config.warmup_ratio))
316
+ decay_steps = max(1, phase_steps - warmup_steps)
317
+
318
+ warmup = LinearLR(
319
+ optimizer,
320
+ start_factor=0.01,
321
+ end_factor=1.0,
322
+ total_iters=warmup_steps,
323
+ )
324
+ cosine = CosineAnnealingLR(
325
+ optimizer,
326
+ T_max=decay_steps,
327
+ eta_min=phase.learning_rate * 0.1,
328
+ )
329
+ return SequentialLR(
330
+ optimizer,
331
+ schedulers=[warmup, cosine],
332
+ milestones=[warmup_steps],
333
+ )
334
+
335
+ def _build_dataloader(
336
+ self, file_path: Path, batch_size: int, shuffle_buffer: int = 10000
337
+ ) -> DataLoader:
338
+ """Build a streaming DataLoader."""
339
+ dataset = StreamingJSONLDataset(
340
+ file_path=file_path,
341
+ tokenizer=self.model.tokenizer,
342
+ max_length=self.config.max_seq_length,
343
+ shuffle_buffer=shuffle_buffer,
344
+ seed=self.config.seed,
345
+ )
346
+ return DataLoader(
347
+ dataset,
348
+ batch_size=batch_size,
349
+ num_workers=self.config.num_workers,
350
+ pin_memory=self.config.pin_memory,
351
+ prefetch_factor=self.config.prefetch_factor if self.config.num_workers > 0 else None,
352
+ drop_last=True,
353
+ )
354
+
355
+ def _log_metrics(self, metrics: dict) -> None:
356
+ """Append metrics to log file and history."""
357
+ self.metrics_history.append(metrics)
358
+ with open(self.log_file, "a", encoding="utf-8") as f:
359
+ f.write(json.dumps(metrics) + "\n")
360
+
361
+ @torch.no_grad()
362
+ def evaluate(self, val_loader: DataLoader, max_batches: int = 50) -> float:
363
+ """
364
+ Run validation and return average loss.
365
+
366
+ Args:
367
+ val_loader: Validation DataLoader.
368
+ max_batches: Maximum batches to evaluate (for speed).
369
+
370
+ Returns:
371
+ Average validation loss.
372
+ """
373
+ self.model.eval()
374
+ total_loss = 0.0
375
+ count = 0
376
+
377
+ for batch_idx, batch in enumerate(val_loader):
378
+ if batch_idx >= max_batches:
379
+ break
380
+
381
+ input_ids = batch["input_ids"].to(self.device)
382
+ attention_mask = batch["attention_mask"].to(self.device)
383
+ labels = batch["labels"].to(self.device)
384
+
385
+ with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
386
+ result = self.model(
387
+ input_ids=input_ids,
388
+ attention_mask=attention_mask,
389
+ labels=labels,
390
+ )
391
+
392
+ if result["loss"] is not None:
393
+ total_loss += result["loss"].item()
394
+ count += 1
395
+
396
+ self.model.train()
397
+ return total_loss / max(count, 1)
398
+
399
+ def _save_checkpoint(self, phase_name: str, step: int, val_loss: float) -> Path:
400
+ """Save a training checkpoint."""
401
+ ckpt_dir = self.config.output_dir / f"{phase_name}_step{step}"
402
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
403
+
404
+ # Save model weights
405
+ self.model.save(ckpt_dir)
406
+
407
+ # Save trainer state
408
+ state = {
409
+ "global_step": self.global_step,
410
+ "phase": phase_name,
411
+ "step_in_phase": step,
412
+ "val_loss": val_loss,
413
+ "best_val_loss": self.best_val_loss,
414
+ }
415
+ torch.save(state, ckpt_dir / "trainer_state.pt")
416
+
417
+ print(f"[MINDITrainer] Checkpoint saved: {ckpt_dir}")
418
+ return ckpt_dir
419
+
420
+ def train_phase(self, phase: PhaseConfig) -> dict:
421
+ """
422
+ Execute a single training phase.
423
+
424
+ Args:
425
+ phase: Phase configuration.
426
+
427
+ Returns:
428
+ Dict with phase training metrics.
429
+ """
430
+ print()
431
+ print("=" * 60)
432
+ print(f" Phase: {phase.name}")
433
+ print(f" Steps: {phase.start_step} β†’ {phase.end_step}")
434
+ print(f" LR: {phase.learning_rate} | Batch: {phase.batch_size}")
435
+ print(f" Components: LoRA={phase.lora}, Vision={phase.vision_projection}, "
436
+ f"Fusion={phase.fusion}")
437
+ print("=" * 60)
438
+
439
+ # Set trainable components
440
+ self.model.set_trainable_components(
441
+ lora=phase.lora,
442
+ vision_projection=phase.vision_projection,
443
+ fusion=phase.fusion,
444
+ )
445
+
446
+ # Build optimizer and scheduler for this phase
447
+ optimizer = self._build_optimizer(phase)
448
+ scheduler = self._build_scheduler(optimizer, phase)
449
+
450
+ # Build data loaders
451
+ train_loader = self._build_dataloader(
452
+ self.config.train_file, phase.batch_size
453
+ )
454
+ val_loader = self._build_dataloader(
455
+ self.config.val_file, batch_size=max(phase.batch_size // 2, 1),
456
+ shuffle_buffer=1000,
457
+ )
458
+
459
+ self.model.train()
460
+ phase_steps = phase.end_step - phase.start_step
461
+ step_in_phase = 0
462
+ accum_loss = 0.0
463
+ accum_count = 0
464
+ phase_start_time = time.time()
465
+
466
+ train_iter = iter(train_loader)
467
+
468
+ while step_in_phase < phase_steps:
469
+ # Get next batch (restart iterator if exhausted = new epoch)
470
+ try:
471
+ batch = next(train_iter)
472
+ except StopIteration:
473
+ train_iter = iter(train_loader)
474
+ batch = next(train_iter)
475
+
476
+ input_ids = batch["input_ids"].to(self.device)
477
+ attention_mask = batch["attention_mask"].to(self.device)
478
+ labels = batch["labels"].to(self.device)
479
+
480
+ # Forward pass with mixed precision
481
+ with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
482
+ result = self.model(
483
+ input_ids=input_ids,
484
+ attention_mask=attention_mask,
485
+ labels=labels,
486
+ )
487
+ loss = result["loss"]
488
+
489
+ if loss is None:
490
+ continue
491
+
492
+ # Scale loss for gradient accumulation
493
+ loss = loss / phase.gradient_accumulation_steps
494
+
495
+ # Backward pass
496
+ loss.backward()
497
+ accum_loss += loss.item() * phase.gradient_accumulation_steps
498
+ accum_count += 1
499
+
500
+ # Optimizer step (every gradient_accumulation_steps)
501
+ if accum_count % phase.gradient_accumulation_steps == 0:
502
+ # Gradient clipping
503
+ torch.nn.utils.clip_grad_norm_(
504
+ [p for p in self.model.parameters() if p.requires_grad],
505
+ self.config.max_grad_norm,
506
+ )
507
+ optimizer.step()
508
+ scheduler.step()
509
+ optimizer.zero_grad()
510
+
511
+ step_in_phase += 1
512
+ self.global_step += 1
513
+ avg_loss = accum_loss / phase.gradient_accumulation_steps
514
+ accum_loss = 0.0
515
+
516
+ # Logging
517
+ if step_in_phase % self.config.log_every_n_steps == 0:
518
+ elapsed = time.time() - phase_start_time
519
+ steps_per_sec = step_in_phase / elapsed if elapsed > 0 else 0.0
520
+ eta_sec = (phase_steps - step_in_phase) / steps_per_sec if steps_per_sec > 0 else 0.0
521
+
522
+ metrics = {
523
+ "phase": phase.name,
524
+ "global_step": self.global_step,
525
+ "step_in_phase": step_in_phase,
526
+ "loss": round(avg_loss, 4),
527
+ "lr": optimizer.param_groups[0]["lr"],
528
+ "steps_per_sec": round(steps_per_sec, 3),
529
+ "eta_minutes": round(eta_sec / 60, 1),
530
+ "elapsed_minutes": round(elapsed / 60, 1),
531
+ }
532
+ self._log_metrics(metrics)
533
+ print(f" [{phase.name}] step {step_in_phase}/{phase_steps} | "
534
+ f"loss={avg_loss:.4f} | "
535
+ f"lr={optimizer.param_groups[0]['lr']:.2e} | "
536
+ f"speed={steps_per_sec:.2f} steps/s | "
537
+ f"ETA={eta_sec / 60:.1f}min")
538
+
539
+ # Evaluation
540
+ if step_in_phase % self.config.eval_every_n_steps == 0:
541
+ val_loss = self.evaluate(val_loader)
542
+ print(f" [{phase.name}] EVAL step {step_in_phase} | val_loss={val_loss:.4f}")
543
+ self._log_metrics({
544
+ "phase": phase.name,
545
+ "global_step": self.global_step,
546
+ "val_loss": round(val_loss, 4),
547
+ "type": "eval",
548
+ })
549
+
550
+ # Save best model
551
+ if val_loss < self.best_val_loss:
552
+ self.best_val_loss = val_loss
553
+ self._save_checkpoint(phase.name, step_in_phase, val_loss)
554
+ print(f" [{phase.name}] New best val_loss: {val_loss:.4f}")
555
+
556
+ # Periodic save
557
+ if step_in_phase % self.config.save_every_n_steps == 0:
558
+ self._save_checkpoint(phase.name, step_in_phase, self.best_val_loss)
559
+
560
+ # End-of-phase save
561
+ phase_elapsed = time.time() - phase_start_time
562
+ self._save_checkpoint(phase.name, step_in_phase, self.best_val_loss)
563
+
564
+ phase_summary = {
565
+ "phase": phase.name,
566
+ "total_steps": step_in_phase,
567
+ "elapsed_minutes": round(phase_elapsed / 60, 1),
568
+ "best_val_loss": round(self.best_val_loss, 4),
569
+ "type": "phase_complete",
570
+ }
571
+ self._log_metrics(phase_summary)
572
+ print(f"\n [{phase.name}] Complete β€” {step_in_phase} steps in "
573
+ f"{phase_elapsed / 60:.1f} min")
574
+
575
+ return phase_summary
576
+
577
+ def train(self) -> dict:
578
+ """
579
+ Run all 3 training phases sequentially.
580
+
581
+ Returns:
582
+ Dict with complete training summary.
583
+ """
584
+ print()
585
+ print("=" * 60)
586
+ print(" MINDI 1.5 β€” Training Start")
587
+ print(f" Total phases: {len(self.config.phases)}")
588
+ print(f" Total steps: {self.config.total_steps}")
589
+ print(f" Device: {self.device}")
590
+ print(f" Dtype: {self.config.dtype}")
591
+ print(f" Output: {self.config.output_dir}")
592
+ print("=" * 60)
593
+
594
+ torch.manual_seed(self.config.seed)
595
+ if torch.cuda.is_available():
596
+ torch.cuda.manual_seed_all(self.config.seed)
597
+
598
+ training_start = time.time()
599
+ phase_summaries = []
600
+
601
+ for phase in self.config.phases:
602
+ summary = self.train_phase(phase)
603
+ phase_summaries.append(summary)
604
+
605
+ total_elapsed = time.time() - training_start
606
+
607
+ # Final save
608
+ final_dir = self.config.output_dir / "final"
609
+ final_dir.mkdir(parents=True, exist_ok=True)
610
+ self.model.save(final_dir)
611
+
612
+ training_summary = {
613
+ "total_steps": self.global_step,
614
+ "total_minutes": round(total_elapsed / 60, 1),
615
+ "best_val_loss": round(self.best_val_loss, 4),
616
+ "phases": phase_summaries,
617
+ "type": "training_complete",
618
+ }
619
+ self._log_metrics(training_summary)
620
+
621
+ print()
622
+ print("=" * 60)
623
+ print(" MINDI 1.5 β€” Training Complete")
624
+ print(f" Total steps: {self.global_step}")
625
+ print(f" Total time: {total_elapsed / 60:.1f} minutes")
626
+ print(f" Best val loss: {self.best_val_loss:.4f}")
627
+ print(f" Final saved to: {final_dir}")
628
+ print("=" * 60)
629
+
630
+ return training_summary
631
+
632
+ def resume_from_checkpoint(self, checkpoint_dir: Path) -> None:
633
+ """
634
+ Resume training from a checkpoint.
635
+
636
+ Args:
637
+ checkpoint_dir: Directory containing saved checkpoint.
638
+ """
639
+ checkpoint_dir = Path(checkpoint_dir)
640
+ state_file = checkpoint_dir / "trainer_state.pt"
641
+
642
+ if not state_file.exists():
643
+ raise FileNotFoundError(f"Trainer state not found: {state_file}")
644
+
645
+ # Load model weights
646
+ self.model.load(checkpoint_dir)
647
+
648
+ # Load trainer state
649
+ state = torch.load(state_file, map_location=self.device, weights_only=True)
650
+ self.global_step = state["global_step"]
651
+ self.best_val_loss = state["best_val_loss"]
652
+
653
+ print(f"[MINDITrainer] Resumed from step {self.global_step} "
654
+ f"(val_loss={self.best_val_loss:.4f})")
655
+
656
+
657
+ # ── Test block ────────────────────────────────────────────────────────
658
+ if __name__ == "__main__":
659
+ print("=" * 60)
660
+ print(" MINDI 1.5 β€” Trainer Test")
661
+ print("=" * 60)
662
+ print()
663
+
664
+ # ── Test 1: Config defaults ──────────────────────────────────
665
+ print(" Test 1: TrainingConfig defaults")
666
+ config = TrainingConfig()
667
+ assert config.total_steps == 10000
668
+ assert config.dtype == "bf16"
669
+ assert config.torch_dtype == torch.bfloat16
670
+ assert len(config.phases) == 3
671
+ assert config.gradient_checkpointing is True
672
+ assert config.num_workers == 4
673
+ assert config.pin_memory is True
674
+ assert config.prefetch_factor == 2
675
+ print(f" Total steps: {config.total_steps}")
676
+ print(f" Dtype: {config.dtype}")
677
+ print(f" Phases: {[p.name for p in config.phases]}")
678
+ print(" βœ“ Config defaults correct")
679
+
680
+ # ── Test 2: Phase configs ────────────────────────────────────
681
+ print("\n Test 2: Phase configurations")
682
+ p1, p2, p3 = config.phases
683
+ assert p1.name == "phase1_lora"
684
+ assert p1.batch_size == 16
685
+ assert p1.learning_rate == 2e-4
686
+ assert p1.lora is True and p1.vision_projection is False and p1.fusion is False
687
+
688
+ assert p2.name == "phase2_vision_bridge"
689
+ assert p2.batch_size == 8
690
+ assert p2.learning_rate == 1e-5
691
+ assert p2.lora is False and p2.vision_projection is True and p2.fusion is True
692
+
693
+ assert p3.name == "phase3_all"
694
+ assert p3.batch_size == 12
695
+ assert p3.learning_rate == 5e-5
696
+ assert p3.lora is True and p3.vision_projection is True and p3.fusion is True
697
+ print(" Phase 1: LoRA only, batch=16, lr=2e-4 βœ“")
698
+ print(" Phase 2: Vision bridge, batch=8, lr=1e-5 βœ“")
699
+ print(" Phase 3: All, batch=12, lr=5e-5 βœ“")
700
+
701
+ # ── Test 3: Streaming dataset (if data exists) ───────────────
702
+ print("\n Test 3: StreamingJSONLDataset")
703
+ train_path = config.train_file
704
+ if train_path.exists():
705
+ from transformers import AutoTokenizer
706
+ tok = AutoTokenizer.from_pretrained(
707
+ str(PROJECT_ROOT / "data" / "tokenizer" / "mindi_tokenizer"),
708
+ trust_remote_code=True,
709
+ )
710
+ dataset = StreamingJSONLDataset(
711
+ file_path=train_path,
712
+ tokenizer=tok,
713
+ max_length=512, # small for test
714
+ shuffle_buffer=100,
715
+ )
716
+ count = 0
717
+ for item in dataset:
718
+ assert "input_ids" in item
719
+ assert "attention_mask" in item
720
+ assert "labels" in item
721
+ assert item["input_ids"].shape[0] == 512
722
+ count += 1
723
+ if count >= 5:
724
+ break
725
+ print(f" Streamed {count} examples, shape={item['input_ids'].shape} βœ“")
726
+ else:
727
+ print(f" [SKIP] Train file not found: {train_path}")
728
+
729
+ # ── Test 4: PhaseConfig step ranges ──────────────────────────
730
+ print("\n Test 4: Phase step continuity")
731
+ for i in range(1, len(config.phases)):
732
+ prev = config.phases[i - 1]
733
+ curr = config.phases[i]
734
+ assert prev.end_step == curr.start_step, \
735
+ f"Gap between {prev.name} and {curr.name}"
736
+ print(" All phases are contiguous βœ“")
737
+
738
+ # ── Test 5: Gradient accumulation ────────────────────────────
739
+ print("\n Test 5: Gradient accumulation steps")
740
+ for phase in config.phases:
741
+ assert phase.gradient_accumulation_steps == 4
742
+ print(" All phases: grad_accum=4 βœ“")
743
+
744
+ print("\n βœ“ All trainer tests passed!")
745
+ print("=" * 60)