Faaz committed on
Commit
672896a
·
1 Parent(s): 4e9835e

Add WebSight vision data pipeline: download script, image-aware data loader, phase data routing

Browse files
configs/training_config.yaml CHANGED
@@ -68,8 +68,14 @@ training:
68
 
69
  # ── Data ───────────────────────────────────────────────────────
70
  data:
 
71
  train_file: "data/processed/train.jsonl" # 4.18GB, 1,304,486 examples
72
  val_file: "data/processed/val.jsonl" # 0.23GB, 72,471 examples
 
 
 
 
 
73
  max_length: 4096
74
  shuffle_buffer: 10000 # Streaming shuffle buffer size
75
  num_workers: 4 # DataLoader workers
 
68
 
69
  # ── Data ───────────────────────────────────────────────────────
70
  data:
71
+ # Text-only code data (Phase 1 + Phase 3)
72
  train_file: "data/processed/train.jsonl" # 4.18GB, 1,304,486 examples
73
  val_file: "data/processed/val.jsonl" # 0.23GB, 72,471 examples
74
+
75
+ # Vision+code data — WebSight UI screenshots (Phase 2 + Phase 3)
76
+ vision_train_file: "data/websight/train.jsonl"
77
+ vision_val_file: "data/websight/val.jsonl"
78
+
79
  max_length: 4096
80
  shuffle_buffer: 10000 # Streaming shuffle buffer size
81
  num_workers: 4 # DataLoader workers
scripts/download_websight.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ MINDI 1.5 Vision-Coder — Download WebSight v0.2 Subset
4
+
5
+ Downloads UI screenshot + HTML/CSS code pairs from HuggingFaceM4/WebSight.
6
+ Saves images to data/websight/images/ and creates data/websight/train.jsonl
7
+ and data/websight/val.jsonl with the MINDI training format.
8
+
9
+ Usage:
10
+ python3 scripts/download_websight.py --num_train 50000 --num_val 2500
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import argparse
16
+ import json
17
+ import os
18
+ import sys
19
+ from pathlib import Path
20
+
21
+ # Add project root to path
22
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
23
+ sys.path.insert(0, str(PROJECT_ROOT))
24
+
25
+
26
def main():
    """Download a WebSight subset and write it out in MINDI training format.

    Streams ``HuggingFaceM4/WebSight`` (so the full dataset is never
    downloaded), saves each screenshot as a JPEG under
    ``<output_dir>/images/`` and writes one JSON record per example to
    ``<output_dir>/train.jsonl`` and ``<output_dir>/val.jsonl``.

    The first ``--num_train`` usable examples go to the train split and the
    next ``--num_val`` go to the val split. Records without an image or with
    empty code are skipped and do not count towards either quota.

    The Hugging Face token is read from the ``HF_TOKEN`` environment variable
    when present.
    """
    # Local import: only needed for the final disk-usage report.
    import subprocess

    parser = argparse.ArgumentParser(description="Download WebSight dataset subset")
    parser.add_argument("--num_train", type=int, default=50000,
                        help="Number of training examples (default: 50000)")
    parser.add_argument("--num_val", type=int, default=2500,
                        help="Number of validation examples (default: 2500)")
    parser.add_argument("--output_dir", type=str, default="data/websight",
                        help="Output directory")
    parser.add_argument("--version", type=str, default="v0.2",
                        help="WebSight version (v0.1 or v0.2)")
    args = parser.parse_args()

    if args.num_train < 0 or args.num_val < 0:
        parser.error("--num_train and --num_val must be non-negative")

    total = args.num_train + args.num_val
    output_dir = Path(args.output_dir)
    images_dir = output_dir / "images"
    images_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print(" MINDI 1.5 — WebSight Dataset Download")
    print("=" * 60)
    print(f" Version: {args.version}")
    print(f" Train: {args.num_train:,}")
    print(f" Val: {args.num_val:,}")
    print(f" Output: {output_dir}")
    print()

    # Load dataset with streaming to avoid downloading everything
    print("[1/3] Loading WebSight dataset (streaming) ...")
    from datasets import load_dataset

    ds = load_dataset(
        "HuggingFaceM4/WebSight",
        args.version,
        split="train",
        streaming=True,
        token=os.environ.get("HF_TOKEN"),
    )

    # Process examples
    print(f"[2/3] Downloading {total:,} examples ...")
    train_path = output_dir / "train.jsonl"
    val_path = output_dir / "val.jsonl"

    count = 0
    # Context managers guarantee both files are flushed and closed even if
    # the stream or an image write fails part-way through.
    with open(train_path, "w", encoding="utf-8") as train_f, \
            open(val_path, "w", encoding="utf-8") as val_f:
        for i, example in enumerate(ds):
            # Stop on the number of examples actually written, not on the
            # stream index, so skipped records do not shrink the val split.
            if count >= total:
                break

            # Extract image and code
            image = example.get("image")
            code = example.get("text") or ""

            if image is None or not code.strip():
                continue

            # Save image. JPEG cannot store alpha/palette modes, so normalise
            # anything that is not already RGB before writing.
            if image.mode != "RGB":
                image = image.convert("RGB")
            img_filename = f"ws_{i:07d}.jpg"
            img_path = images_dir / img_filename
            image.save(str(img_path), "JPEG", quality=85)

            # Create MINDI-format training example
            entry = {
                "id": f"websight_{i:07d}",
                "type": "vision_code",
                # Follow --version instead of always claiming v0.2.
                "source": f"websight_{args.version}",
                # Follow --output_dir so the recorded path points at the file
                # that was actually written (identical to the previous value
                # for the default output directory).
                "image_path": img_path.as_posix(),
                "messages": [
                    {
                        "role": "system",
                        "content": "You are MINDI 1.5 Vision-Coder, a specialized AI for understanding UI screenshots and generating accurate HTML/CSS code."
                    },
                    {
                        "role": "user",
                        "content": "<|vision_start|><|vision_end|>\nGenerate the HTML/CSS code for this UI screenshot."
                    },
                    {
                        "role": "assistant",
                        "content": f"<|think_start|>I'll analyze the UI layout and generate the corresponding code.<|think_end|>\n<|code_start|>\n{code.strip()}\n<|code_end|>"
                    }
                ],
                "metadata": {
                    "dataset": "websight",
                    "version": args.version,
                }
            }

            # Split: first num_train → train, rest → val
            target = train_f if count < args.num_train else val_f
            target.write(json.dumps(entry, ensure_ascii=False) + "\n")

            count += 1
            if count % 1000 == 0:
                print(f" {count:,}/{total:,} downloaded ...")

    # Stats
    train_count = min(count, args.num_train)
    val_count = max(0, count - args.num_train)

    print("\n[3/3] Done!")
    print(f" Train: {train_count:,} examples → {train_path}")
    print(f" Val: {val_count:,} examples → {val_path}")
    print(f" Images: {images_dir}")
    # Flush so the label is emitted before the child process writes its line.
    print(" Disk: ", end="", flush=True)
    try:
        # Argument list (no shell): safe for paths with spaces/metacharacters.
        subprocess.run(["du", "-sh", str(output_dir)], check=False)
    except OSError:
        # `du` is unavailable (e.g. Windows): fall back to summing file sizes.
        size_bytes = sum(
            p.stat().st_size for p in output_dir.rglob("*") if p.is_file()
        )
        print(f"{size_bytes / 1e6:,.1f} MB\t{output_dir}")
139
+
140
+ if __name__ == "__main__":
141
+ main()
scripts/train.py CHANGED
@@ -84,12 +84,12 @@ def build_training_config(raw: dict, dry_run: bool = False):
84
  # Build phase configs from YAML
85
  phases = []
86
  phase_defs = [
87
- ("phase1", "phase1_lora", True, False, False),
88
- ("phase2", "phase2_vision_bridge", False, True, True),
89
- ("phase3", "phase3_all", True, True, True),
90
  ]
91
  cumulative_step = 0
92
- for key, name, lora, vision, fusion in phase_defs:
93
  pcfg = training.get(key, {})
94
  steps = pcfg.get("steps", 2500)
95
  if dry_run:
@@ -106,12 +106,15 @@ def build_training_config(raw: dict, dry_run: bool = False):
106
  lora=lora,
107
  vision_projection=vision,
108
  fusion=fusion,
 
109
  ))
110
  cumulative_step = end
111
 
112
  config = TrainingConfig(
113
  train_file=PROJECT_ROOT / data.get("train_file", "data/processed/train.jsonl"),
114
  val_file=PROJECT_ROOT / data.get("val_file", "data/processed/val.jsonl"),
 
 
115
  output_dir=PROJECT_ROOT / output.get("checkpoint_dir", "checkpoints/training"),
116
  log_dir=PROJECT_ROOT / logging_cfg.get("log_dir", "logs/training"),
117
  max_seq_length=data.get("max_length", 4096),
 
84
  # Build phase configs from YAML
85
  phases = []
86
  phase_defs = [
87
+ ("phase1", "phase1_lora", True, False, False, "text"),
88
+ ("phase2", "phase2_vision_bridge", False, True, True, "vision"),
89
+ ("phase3", "phase3_all", True, True, True, "mixed"),
90
  ]
91
  cumulative_step = 0
92
+ for key, name, lora, vision, fusion, data_type in phase_defs:
93
  pcfg = training.get(key, {})
94
  steps = pcfg.get("steps", 2500)
95
  if dry_run:
 
106
  lora=lora,
107
  vision_projection=vision,
108
  fusion=fusion,
109
+ data_type=data_type,
110
  ))
111
  cumulative_step = end
112
 
113
  config = TrainingConfig(
114
  train_file=PROJECT_ROOT / data.get("train_file", "data/processed/train.jsonl"),
115
  val_file=PROJECT_ROOT / data.get("val_file", "data/processed/val.jsonl"),
116
+ vision_train_file=PROJECT_ROOT / data.get("vision_train_file", "data/websight/train.jsonl"),
117
+ vision_val_file=PROJECT_ROOT / data.get("vision_val_file", "data/websight/val.jsonl"),
118
  output_dir=PROJECT_ROOT / output.get("checkpoint_dir", "checkpoints/training"),
119
  log_dir=PROJECT_ROOT / logging_cfg.get("log_dir", "logs/training"),
120
  max_seq_length=data.get("max_length", 4096),
src/training/mindi_trainer.py CHANGED
@@ -28,6 +28,7 @@ from typing import Any, Iterator, Optional
28
 
29
  import torch
30
  import torch.nn as nn
 
31
  from torch.optim import AdamW
32
  from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
33
  from torch.utils.data import DataLoader, IterableDataset
@@ -50,6 +51,8 @@ class PhaseConfig:
50
  lora: bool = False
51
  vision_projection: bool = False
52
  fusion: bool = False
 
 
53
 
54
 
55
  @dataclass
@@ -59,6 +62,8 @@ class TrainingConfig:
59
  # Data paths
60
  train_file: Path = field(default_factory=lambda: PROJECT_ROOT / "data" / "processed" / "train.jsonl")
61
  val_file: Path = field(default_factory=lambda: PROJECT_ROOT / "data" / "processed" / "val.jsonl")
 
 
62
 
63
  # Output
64
  output_dir: Path = field(default_factory=lambda: PROJECT_ROOT / "checkpoints" / "training")
@@ -93,18 +98,21 @@ class TrainingConfig:
93
  start_step=0, end_step=5000,
94
  learning_rate=2e-4, batch_size=16,
95
  lora=True, vision_projection=False, fusion=False,
 
96
  ),
97
  PhaseConfig(
98
  name="phase2_vision_bridge",
99
  start_step=5000, end_step=7500,
100
  learning_rate=1e-5, batch_size=8,
101
  lora=False, vision_projection=True, fusion=True,
 
102
  ),
103
  PhaseConfig(
104
  name="phase3_all",
105
  start_step=7500, end_step=10000,
106
  learning_rate=5e-5, batch_size=12,
107
  lora=True, vision_projection=True, fusion=True,
 
108
  ),
109
  ])
110
 
@@ -123,9 +131,11 @@ class StreamingJSONLDataset(IterableDataset):
123
  """
124
  Streams JSONL training data from disk line by line.
125
  Tokenizes on-the-fly to avoid loading 4+ GB into RAM.
 
126
 
127
  Expected JSONL format:
128
  {"id": "...", "type": "...", "source": "...",
 
129
  "messages": [{"role": "system", "content": "..."},
130
  {"role": "user", "content": "..."},
131
  {"role": "assistant", "content": "..."}],
@@ -139,12 +149,14 @@ class StreamingJSONLDataset(IterableDataset):
139
  max_length: int = 8192,
140
  shuffle_buffer: int = 10000,
141
  seed: int = 42,
 
142
  ) -> None:
143
  self.file_path = Path(file_path)
144
  self.tokenizer = tokenizer
145
  self.max_length = max_length
146
  self.shuffle_buffer = shuffle_buffer
147
  self.seed = seed
 
148
 
149
  if not self.file_path.exists():
150
  raise FileNotFoundError(f"Training data not found: {self.file_path}")
@@ -212,7 +224,18 @@ class StreamingJSONLDataset(IterableDataset):
212
  rng.shuffle(buffer)
213
  yield from buffer
214
 
215
- def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
 
 
 
 
 
 
 
 
 
 
 
216
  for example in self._shuffled_iterator():
217
  messages = example.get("messages", [])
218
  if not messages:
@@ -220,6 +243,12 @@ class StreamingJSONLDataset(IterableDataset):
220
  text = self._format_messages(messages)
221
  tokenized = self._tokenize(text)
222
  if tokenized is not None:
 
 
 
 
 
 
223
  yield tokenized
224
 
225
  def count_lines(self) -> int:
@@ -342,6 +371,17 @@ class MINDITrainer:
342
  shuffle_buffer=shuffle_buffer,
343
  seed=self.config.seed,
344
  )
 
 
 
 
 
 
 
 
 
 
 
345
  return DataLoader(
346
  dataset,
347
  batch_size=batch_size,
@@ -349,6 +389,7 @@ class MINDITrainer:
349
  pin_memory=self.config.pin_memory,
350
  prefetch_factor=self.config.prefetch_factor if self.config.num_workers > 0 else None,
351
  drop_last=True,
 
352
  )
353
 
354
  def _log_metrics(self, metrics: dict) -> None:
@@ -380,12 +421,20 @@ class MINDITrainer:
380
  input_ids = batch["input_ids"].to(self.device)
381
  attention_mask = batch["attention_mask"].to(self.device)
382
  labels = batch["labels"].to(self.device)
 
 
 
 
 
 
 
383
 
384
  with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
385
  result = self.model(
386
  input_ids=input_ids,
387
  attention_mask=attention_mask,
388
  labels=labels,
 
389
  )
390
 
391
  if result["loss"] is not None:
@@ -433,6 +482,7 @@ class MINDITrainer:
433
  print(f" LR: {phase.learning_rate} | Batch: {phase.batch_size}")
434
  print(f" Components: LoRA={phase.lora}, Vision={phase.vision_projection}, "
435
  f"Fusion={phase.fusion}")
 
436
  print("=" * 60)
437
 
438
  # Set trainable components
@@ -446,12 +496,21 @@ class MINDITrainer:
446
  optimizer = self._build_optimizer(phase)
447
  scheduler = self._build_scheduler(optimizer, phase)
448
 
 
 
 
 
 
 
 
 
 
449
  # Build data loaders
450
  train_loader = self._build_dataloader(
451
- self.config.train_file, phase.batch_size
452
  )
453
  val_loader = self._build_dataloader(
454
- self.config.val_file, batch_size=max(phase.batch_size // 2, 1),
455
  shuffle_buffer=1000,
456
  )
457
 
@@ -475,6 +534,15 @@ class MINDITrainer:
475
  input_ids = batch["input_ids"].to(self.device)
476
  attention_mask = batch["attention_mask"].to(self.device)
477
  labels = batch["labels"].to(self.device)
 
 
 
 
 
 
 
 
 
478
 
479
  # Forward pass with mixed precision
480
  with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
@@ -482,6 +550,7 @@ class MINDITrainer:
482
  input_ids=input_ids,
483
  attention_mask=attention_mask,
484
  labels=labels,
 
485
  )
486
  loss = result["loss"]
487
 
 
28
 
29
  import torch
30
  import torch.nn as nn
31
+ from PIL import Image
32
  from torch.optim import AdamW
33
  from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
34
  from torch.utils.data import DataLoader, IterableDataset
 
51
  lora: bool = False
52
  vision_projection: bool = False
53
  fusion: bool = False
54
+ # Data type: "text" for code-only, "vision" for image+code, "mixed" for both
55
+ data_type: str = "text"
56
 
57
 
58
  @dataclass
 
62
  # Data paths
63
  train_file: Path = field(default_factory=lambda: PROJECT_ROOT / "data" / "processed" / "train.jsonl")
64
  val_file: Path = field(default_factory=lambda: PROJECT_ROOT / "data" / "processed" / "val.jsonl")
65
+ vision_train_file: Path = field(default_factory=lambda: PROJECT_ROOT / "data" / "websight" / "train.jsonl")
66
+ vision_val_file: Path = field(default_factory=lambda: PROJECT_ROOT / "data" / "websight" / "val.jsonl")
67
 
68
  # Output
69
  output_dir: Path = field(default_factory=lambda: PROJECT_ROOT / "checkpoints" / "training")
 
98
  start_step=0, end_step=5000,
99
  learning_rate=2e-4, batch_size=16,
100
  lora=True, vision_projection=False, fusion=False,
101
+ data_type="text",
102
  ),
103
  PhaseConfig(
104
  name="phase2_vision_bridge",
105
  start_step=5000, end_step=7500,
106
  learning_rate=1e-5, batch_size=8,
107
  lora=False, vision_projection=True, fusion=True,
108
+ data_type="vision",
109
  ),
110
  PhaseConfig(
111
  name="phase3_all",
112
  start_step=7500, end_step=10000,
113
  learning_rate=5e-5, batch_size=12,
114
  lora=True, vision_projection=True, fusion=True,
115
+ data_type="mixed",
116
  ),
117
  ])
118
 
 
131
  """
132
  Streams JSONL training data from disk line by line.
133
  Tokenizes on-the-fly to avoid loading 4+ GB into RAM.
134
+ Supports optional image loading for vision-code pairs.
135
 
136
  Expected JSONL format:
137
  {"id": "...", "type": "...", "source": "...",
138
+ "image_path": "data/websight/images/ws_0000001.jpg", (optional)
139
  "messages": [{"role": "system", "content": "..."},
140
  {"role": "user", "content": "..."},
141
  {"role": "assistant", "content": "..."}],
 
149
  max_length: int = 8192,
150
  shuffle_buffer: int = 10000,
151
  seed: int = 42,
152
+ project_root: Optional[Path] = None,
153
  ) -> None:
154
  self.file_path = Path(file_path)
155
  self.tokenizer = tokenizer
156
  self.max_length = max_length
157
  self.shuffle_buffer = shuffle_buffer
158
  self.seed = seed
159
+ self.project_root = Path(project_root) if project_root else PROJECT_ROOT
160
 
161
  if not self.file_path.exists():
162
  raise FileNotFoundError(f"Training data not found: {self.file_path}")
 
224
  rng.shuffle(buffer)
225
  yield from buffer
226
 
227
def _load_image(self, image_path: str) -> Optional[Image.Image]:
    """Load the image for a dataset record as an RGB PIL image.

    ``image_path`` is joined onto ``self.project_root``. Loading is
    best-effort: a missing, unreadable or corrupt file yields ``None``
    instead of raising, so a single bad record does not abort streaming.
    A warning is logged in that case so dropped images are visible rather
    than silently turning vision samples into text-only ones.

    Args:
        image_path: Path stored in the record's ``image_path`` field.

    Returns:
        The decoded image converted to RGB, or ``None`` if it could not be
        loaded.
    """
    # Local import keeps this method self-contained.
    import logging

    log = logging.getLogger(__name__)
    try:
        full_path = self.project_root / image_path
        if not full_path.exists():
            log.warning("Image not found, continuing without it: %s", full_path)
            return None
        # The context manager releases the file handle promptly; convert()
        # returns an independent copy that stays valid after closing.
        with Image.open(str(full_path)) as img:
            return img.convert("RGB")
    except Exception as exc:  # deliberate best-effort: never break iteration
        log.warning("Could not load image %r: %s", image_path, exc)
        return None
237
+
238
+ def __iter__(self) -> Iterator[dict[str, Any]]:
239
  for example in self._shuffled_iterator():
240
  messages = example.get("messages", [])
241
  if not messages:
 
243
  text = self._format_messages(messages)
244
  tokenized = self._tokenize(text)
245
  if tokenized is not None:
246
+ # Load image if path present
247
+ image_path = example.get("image_path")
248
+ if image_path:
249
+ tokenized["image"] = self._load_image(image_path)
250
+ else:
251
+ tokenized["image"] = None
252
  yield tokenized
253
 
254
  def count_lines(self) -> int:
 
371
  shuffle_buffer=shuffle_buffer,
372
  seed=self.config.seed,
373
  )
374
+
375
+ def _collate_fn(batch):
376
+ """Custom collate: stack tensors, keep images as list."""
377
+ collated = {
378
+ "input_ids": torch.stack([b["input_ids"] for b in batch]),
379
+ "attention_mask": torch.stack([b["attention_mask"] for b in batch]),
380
+ "labels": torch.stack([b["labels"] for b in batch]),
381
+ "images": [b.get("image") for b in batch],
382
+ }
383
+ return collated
384
+
385
  return DataLoader(
386
  dataset,
387
  batch_size=batch_size,
 
389
  pin_memory=self.config.pin_memory,
390
  prefetch_factor=self.config.prefetch_factor if self.config.num_workers > 0 else None,
391
  drop_last=True,
392
+ collate_fn=_collate_fn,
393
  )
394
 
395
  def _log_metrics(self, metrics: dict) -> None:
 
421
  input_ids = batch["input_ids"].to(self.device)
422
  attention_mask = batch["attention_mask"].to(self.device)
423
  labels = batch["labels"].to(self.device)
424
+ images = batch.get("images")
425
+ image = None
426
+ if images:
427
+ for img in images:
428
+ if img is not None:
429
+ image = img
430
+ break
431
 
432
  with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
433
  result = self.model(
434
  input_ids=input_ids,
435
  attention_mask=attention_mask,
436
  labels=labels,
437
+ image=image,
438
  )
439
 
440
  if result["loss"] is not None:
 
482
  print(f" LR: {phase.learning_rate} | Batch: {phase.batch_size}")
483
  print(f" Components: LoRA={phase.lora}, Vision={phase.vision_projection}, "
484
  f"Fusion={phase.fusion}")
485
+ print(f" Data: {phase.data_type}")
486
  print("=" * 60)
487
 
488
  # Set trainable components
 
496
  optimizer = self._build_optimizer(phase)
497
  scheduler = self._build_scheduler(optimizer, phase)
498
 
499
+ # Select data files based on phase data_type
500
+ if phase.data_type == "vision":
501
+ train_file = self.config.vision_train_file
502
+ val_file = self.config.vision_val_file
503
+ else:
504
+ # "text" or "mixed" — use main data (mixed has images inline)
505
+ train_file = self.config.train_file
506
+ val_file = self.config.val_file
507
+
508
  # Build data loaders
509
  train_loader = self._build_dataloader(
510
+ train_file, phase.batch_size
511
  )
512
  val_loader = self._build_dataloader(
513
+ val_file, batch_size=max(phase.batch_size // 2, 1),
514
  shuffle_buffer=1000,
515
  )
516
 
 
534
  input_ids = batch["input_ids"].to(self.device)
535
  attention_mask = batch["attention_mask"].to(self.device)
536
  labels = batch["labels"].to(self.device)
537
+ images = batch.get("images") # list of PIL Images or Nones
538
+
539
+ # Pick first non-None image in batch (model processes one image at a time)
540
+ image = None
541
+ if images:
542
+ for img in images:
543
+ if img is not None:
544
+ image = img
545
+ break
546
 
547
  # Forward pass with mixed precision
548
  with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
 
550
  input_ids=input_ids,
551
  attention_mask=attention_mask,
552
  labels=labels,
553
+ image=image,
554
  )
555
  loss = result["loss"]
556