# ============================================================================
# STAGE 1: PRECOMPUTE EMBEDDINGS — DATALOADER PIPELINE (CORRECTED) -> flickr fixed
#
# Architecture:
#   HF load_dataset
#     → custom torch.Dataset (__getitem__ does CPU tokenization + image processing)
#     → DataLoader (workers do CPU I/O)
#     → GPU encode
#     → shard-safe HF Arrow writes
#     → concatenate shards
#     → save_to_disk final dataset
# ============================================================================

# Fix broken sympy before torch imports it
import subprocess
import sys

try:
    import sympy
    _ = sympy.core
except (ImportError, AttributeError):
    print("Fixing sympy...")
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--upgrade", "sympy",
         "--break-system-packages", "-q"]
    )
    print(" sympy upgraded. Restart kernel if needed.")

import gc
import json
import math
import os
import shutil
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from datasets import (
    Dataset as HFDataset,
    Features,
    Sequence,
    Value,
    Array2D,
    concatenate_datasets,
    load_dataset,
    load_from_disk,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ══════════════════════════════════════════════════════════════════
# CONFIG
# ══════════════════════════════════════════════════════════════════

@dataclass
class Stage1Config:
    """Tunables for the Stage-1 precompute pipeline."""
    cache_dir: str = "/home/claude/geo_cache"
    max_text_len: int = 32
    batch_size: int = 512
    num_workers: int = 8
    shard_size: int = 2048        # number of valid encoded samples per shard
    writer_batch_size: int = 256  # HF internal writer batch size
    pin_memory: bool = torch.cuda.is_available()
    prefetch_factor: int = 2
    cleanup_shards_after_merge: bool = True
    print_every: int = 1000


CFG = Stage1Config()


# ══════════════════════════════════════════════════════════════════
# HELPERS
# ══════════════════════════════════════════════════════════════════

def extract_caption(sample: Dict[str, Any]) -> str:
    """
    Deterministic caption extraction.
    Keeps your original heuristic, but isolates it for clarity and future
    replacement.

    Tries a fixed list of candidate keys in order; for each, accepts a
    non-empty string, or the first element of a non-empty list (string,
    dict with 'raw'/'text', or anything str()-able). Returns "" when no
    caption is found.
    """
    for key in ["answer", "caption", "captions", "sentences", "text"]:
        if key not in sample:
            continue
        val = sample[key]
        if isinstance(val, str):
            caption = val.strip()
            if caption:
                return caption
        if isinstance(val, list) and val:
            item = val[0]
            if isinstance(item, str):
                caption = item.strip()
                if caption:
                    return caption
            if isinstance(item, dict):
                caption = str(item.get("raw", item.get("text", ""))).strip()
                if caption:
                    return caption
            # Last resort: stringify whatever the first list element is.
            caption = str(item).strip()
            if caption:
                return caption
    return ""


def make_dataloader(dataset: Dataset, batch_size: int, num_workers: int = 8,
                    shuffle: bool = False) -> DataLoader:
    """DataLoader with pinned memory and prefetch."""
    kwargs = dict(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=CFG.pin_memory,
        persistent_workers=num_workers > 0,
    )
    # prefetch_factor is only valid when there are worker processes.
    if num_workers > 0:
        kwargs["prefetch_factor"] = CFG.prefetch_factor
    return DataLoader(**kwargs)


def flush_shard(
    shard_root: str,
    shard_index: int,
    features: Features,
    shard_rows: Dict[str, List[Any]],
    writer_batch_size: int,
) -> Optional[str]:
    """
    Flush one shard to disk and clear in-memory shard rows.

    Returns the shard's on-disk path, or None when there are no rows to
    write.

    NOTE(review): `writer_batch_size` is accepted for interface stability
    but is not consumed by `Dataset.from_dict`/`save_to_disk` — it is
    currently unused here.
    """
    n_rows = len(shard_rows["source_idx"])
    if n_rows == 0:
        return None
    shard_path = os.path.join(shard_root, f"shard_{shard_index:05d}")
    os.makedirs(shard_root, exist_ok=True)
    ds = HFDataset.from_dict(shard_rows, features=features)
    ds.save_to_disk(shard_path)
    return shard_path


def reset_shard_rows() -> Dict[str, List[Any]]:
    """Fresh, empty column buffers for one shard."""
    return {
        "source_idx": [],
        "text_hidden": [],
        "text_mask": [],
        "image_hidden": [],
    }


def write_manifest(path: str, data: Dict[str, Any]) -> None:
    """Write a JSON manifest describing one cached dataset."""
    with open(path, "w") as f:
        json.dump(data, f, indent=2)


# ══════════════════════════════════════════════════════════════════
# TORCH DATASET — workers do tokenization + image processing
# ══════════════════════════════════════════════════════════════════

class ImageTextDataset(Dataset):
    """
    Wraps an HF dataset. __getitem__ does ALL CPU work: caption extraction,
    tokenization, image processing. DataLoader workers call this in parallel.
    Returns tensors ready for GPU forward, plus source index and validity flag.
    """

    def __init__(self, hf_dataset, tokenizer, image_processor, max_text_len: int):
        self.ds = hf_dataset
        self.tok = tokenizer
        self.proc = image_processor
        self.max_text_len = max_text_len
        # Determine expected pixel tensor shape once for invalid fallbacks.
        # If processor output shape differs in practice, valid samples define
        # the real downstream contract.
        self.fallback_pixel_shape = self._infer_fallback_pixel_shape()

    def _infer_fallback_pixel_shape(self) -> Tuple[int, int, int]:
        # Dinov2 image processor usually produces 3x518x518 for this model
        # family. We try to infer more cleanly when possible, otherwise fall
        # back.
        size = getattr(self.proc, "size", None)
        if isinstance(size, dict):
            h = size.get("height", size.get("shortest_edge", 518))
            w = size.get("width", size.get("shortest_edge", 518))
            return (3, int(h), int(w))
        return (3, 518, 518)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        sample = self.ds[idx]

        # Caption
        caption = extract_caption(sample)

        # Tokenize (CPU)
        tokens = self.tok(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_text_len,
            return_tensors="pt",
        )
        input_ids = tokens["input_ids"].squeeze(0)
        attn_mask = tokens["attention_mask"].squeeze(0)

        # Image processing (CPU — resize, normalize, to tensor)
        image = sample.get("image", None)
        valid = True
        if image is not None and hasattr(image, "convert"):
            try:
                image = image.convert("RGB")
                pixel_values = self.proc(images=image, return_tensors="pt")["pixel_values"].squeeze(0)
            except Exception:
                # Corrupt/undecodable image: emit a zero tensor and flag it
                # so the encode loop can drop this row.
                pixel_values = torch.zeros(self.fallback_pixel_shape, dtype=torch.float32)
                valid = False
        else:
            pixel_values = torch.zeros(self.fallback_pixel_shape, dtype=torch.float32)
            valid = False

        return idx, input_ids, attn_mask, pixel_values, valid


# ══════════════════════════════════════════════════════════════════
# FULL PIPELINE
# ══════════════════════════════════════════════════════════════════

def process_and_cache(
    dataset_id: str,
    split: str,
    max_samples: Optional[int],
    batch_size: int = 512,
    num_workers: int = 8,
    shard_size: int = 2048,
    tag: Optional[str] = None,
    bert=None,
    dino=None,
    tokenizer=None,
    processor=None,
) -> str:
    """
    Full pipeline:
      1. load_dataset → HF Dataset
      2. Wrap in torch Dataset (tokenize + image process in workers)
      3. DataLoader → GPU encode
      4. Write shard-safe Arrow datasets
      5. Concatenate shards → save final dataset

    Returns the path of the final cached dataset. Idempotent: if the cache
    already exists, it is returned immediately without re-encoding.
    """
    assert bert is not None
    assert dino is not None
    assert tokenizer is not None
    assert processor is not None

    tag = tag or f"{dataset_id.replace('/', '_')}_{split}"
    cache_path = os.path.join(CFG.cache_dir, tag)
    shard_root = os.path.join(CFG.cache_dir, f"{tag}__shards")
    manifest_path = os.path.join(CFG.cache_dir, f"{tag}__manifest.json")

    if os.path.exists(cache_path):
        print(f" Cache exists: {cache_path}")
        ds = load_from_disk(cache_path)
        print(f" {len(ds)} samples cached")
        return cache_path

    os.makedirs(CFG.cache_dir, exist_ok=True)

    print(f"\n Loading {dataset_id} ({split})...")
    t0 = time.time()
    hf_ds = load_dataset(dataset_id, split=split)
    raw_total = len(hf_ds)
    print(f" Dataset: {raw_total} samples")

    # Truncate raw source dataset if requested
    if max_samples is not None and raw_total > max_samples:
        hf_ds = hf_ds.select(range(max_samples))
        print(f" Truncated raw dataset to {len(hf_ds)}")
        raw_total = len(hf_ds)

    first = hf_ds[0]
    print(f" Columns: {list(first.keys())}")

    torch_ds = ImageTextDataset(
        hf_dataset=hf_ds,
        tokenizer=tokenizer,
        image_processor=processor,
        max_text_len=CFG.max_text_len,
    )
    loader = make_dataloader(
        dataset=torch_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )

    print(" Encoding...")
    feature_schema: Optional[Features] = None
    shard_rows = reset_shard_rows()
    shard_paths: List[str] = []
    shard_index = 0
    raw_seen = 0
    valid_saved = 0
    invalid_dropped = 0

    for batch in loader:
        source_idx, input_ids, attn_mask, pixel_values, valid = batch
        batch_raw = int(input_ids.shape[0])
        raw_seen += batch_raw
        valid_b = valid.bool()
        invalid_dropped += int((~valid_b).sum().item())

        if not valid_b.any():
            # Whole batch invalid: keep the progress cadence, then skip.
            if raw_seen % CFG.print_every < batch_raw or raw_seen <= batch_raw:
                rate = raw_seen / max(time.time() - t0, 1e-6)
                print(f" raw={raw_seen}/{raw_total} valid={valid_saved} invalid={invalid_dropped} ({rate:.0f} raw/s)")
            continue

        source_idx_v = source_idx[valid_b]
        input_ids_v = input_ids[valid_b].to(device, non_blocking=True)
        attn_mask_v = attn_mask[valid_b].to(device, non_blocking=True)
        pixel_values_v = pixel_values[valid_b].to(device, non_blocking=True)

        with torch.no_grad():
            if torch.cuda.is_available():
                with torch.amp.autocast("cuda", enabled=True):
                    text_h = bert(input_ids=input_ids_v, attention_mask=attn_mask_v).last_hidden_state
                    image_h = dino(pixel_values=pixel_values_v).last_hidden_state
            else:
                text_h = bert(input_ids=input_ids_v, attention_mask=attn_mask_v).last_hidden_state
                image_h = dino(pixel_values=pixel_values_v).last_hidden_state

        # Move to CPU as fp16 numpy to keep shard memory/disk footprint small.
        text_h = text_h.detach().to(dtype=torch.float16).cpu().numpy()
        text_m = attn_mask_v.bool().cpu().numpy()
        image_h = image_h.detach().to(dtype=torch.float16).cpu().numpy()
        source_idx_np = source_idx_v.cpu().numpy().astype(np.int64)

        # Establish explicit schema once from the first valid encoded batch.
        if feature_schema is None:
            text_shape = tuple(text_h.shape[1:])
            image_shape = tuple(image_h.shape[1:])
            feature_schema = Features({
                "source_idx": Value("int64"),
                "text_hidden": Array2D(shape=text_shape, dtype="float16"),
                "text_mask": Sequence(Value("bool"), length=text_shape[0]),
                "image_hidden": Array2D(shape=image_shape, dtype="float16"),
            })
            print(f" Feature schema:")
            print(f" text_hidden: {text_shape} float16")
            print(f" text_mask: ({text_shape[0]},) bool")
            print(f" image_hidden:{image_shape} float16")

        # Accumulate only the current shard in memory.
        for i in range(text_h.shape[0]):
            shard_rows["source_idx"].append(int(source_idx_np[i]))
            shard_rows["text_hidden"].append(text_h[i])
            shard_rows["text_mask"].append(text_m[i].tolist())
            shard_rows["image_hidden"].append(image_h[i])
        valid_saved += int(text_h.shape[0])

        if valid_saved % CFG.print_every < text_h.shape[0] or valid_saved <= text_h.shape[0]:
            rate = raw_seen / max(time.time() - t0, 1e-6)
            print(
                f" raw={raw_seen}/{raw_total} valid={valid_saved} "
                f"invalid={invalid_dropped} ({rate:.0f} raw/s)"
            )

        if len(shard_rows["source_idx"]) >= shard_size:
            shard_path = flush_shard(
                shard_root=shard_root,
                shard_index=shard_index,
                features=feature_schema,
                shard_rows=shard_rows,
                writer_batch_size=CFG.writer_batch_size,
            )
            if shard_path is not None:
                shard_paths.append(shard_path)
                print(f" Flushed shard {shard_index:05d} ({len(load_from_disk(shard_path))} rows)")
            shard_index += 1
            shard_rows = reset_shard_rows()

    # Flush tail shard
    if feature_schema is None:
        raise RuntimeError("No valid samples were encoded. Cannot build cache.")
    tail_path = flush_shard(
        shard_root=shard_root,
        shard_index=shard_index,
        features=feature_schema,
        shard_rows=shard_rows,
        writer_batch_size=CFG.writer_batch_size,
    )
    if tail_path is not None:
        shard_paths.append(tail_path)
        print(f" Flushed shard {shard_index:05d} ({len(load_from_disk(tail_path))} rows)")

    # Merge shards into final dataset
    print(" Merging shards...")
    shard_datasets = [load_from_disk(p) for p in shard_paths]
    result_ds = concatenate_datasets(shard_datasets)
    result_ds.save_to_disk(cache_path)
    elapsed = time.time() - t0
    print(f" Saved {len(result_ds)} samples to {cache_path} ({elapsed:.0f}s)")

    manifest = {
        "dataset_id": dataset_id,
        "split": split,
        "tag": tag,
        "cache_path": cache_path,
        "raw_total_considered": raw_total,
        "raw_seen": raw_seen,
        "valid_saved": valid_saved,
        "invalid_dropped": invalid_dropped,
        "invalid_rate": (invalid_dropped / raw_seen) if raw_seen > 0 else 0.0,
        "num_shards": len(shard_paths),
        "feature_schema": {
            "text_hidden_shape": list(feature_schema["text_hidden"].shape),
            "text_mask_len": feature_schema["text_mask"].length,
            "image_hidden_shape": list(feature_schema["image_hidden"].shape),
        },
        "elapsed_sec": elapsed,
    }
    write_manifest(manifest_path, manifest)
    print(f" Wrote manifest: {manifest_path}")

    # Cleanup shard directories if requested
    if CFG.cleanup_shards_after_merge and os.path.exists(shard_root):
        shutil.rmtree(shard_root, ignore_errors=True)
        print(f" Removed temporary shards: {shard_root}")

    # Free RAM/VRAM between datasets
    del result_ds
    del shard_datasets
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return cache_path


# ══════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════

if __name__ == "__main__":
    os.makedirs(CFG.cache_dir, exist_ok=True)
    print("=" * 70)
    print("STAGE 1: PRECOMPUTE EMBEDDINGS")
    print("=" * 70)
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"Cache dir: {CFG.cache_dir}")

    # Load encoders ONCE — shared across all datasets
    print("\nLoading encoders...")
    from transformers import BertModel, BertTokenizer, Dinov2Model, AutoImageProcessor

    tokenizer = BertTokenizer.from_pretrained("google-bert/bert-large-uncased")
    processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
    bert = BertModel.from_pretrained(
        "google-bert/bert-large-uncased",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).to(device).eval()
    dino = Dinov2Model.from_pretrained(
        "facebook/dinov2-large",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).to(device).eval()
    print(" Encoders ready.")

    paths = {}

    # ── COCO val — FULL ──
    print(f"\n{'─' * 50}")
    print("[1/3] COCO-Caption val (training) — FULL")
    paths["coco_val"] = process_and_cache(
        dataset_id="lmms-lab/COCO-Caption",
        split="val",
        max_samples=None,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shard_size=CFG.shard_size,
        tag="coco_val",
        bert=bert,
        dino=dino,
        tokenizer=tokenizer,
        processor=processor,
    )

    # ── COCO test — FULL ──
    print(f"\n{'─' * 50}")
    print("[2/3] COCO-Caption test (held-out) — FULL")
    paths["coco_test"] = process_and_cache(
        dataset_id="lmms-lab/COCO-Caption",
        split="test",
        max_samples=None,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shard_size=CFG.shard_size,
        tag="coco_test",
        bert=bert,
        dino=dino,
        tokenizer=tokenizer,
        processor=processor,
    )

    # ── Flickr30k — FULL ──
    print(f"\n{'─' * 50}")
    print("[3/3] Flickr30k (cross-dataset) — FULL")
    try:
        paths["flickr"] = process_and_cache(
            dataset_id="Mozilla/flickr30k-transformed-captions",
            split="test",
            max_samples=None,
            batch_size=CFG.batch_size,
            num_workers=CFG.num_workers,
            shard_size=CFG.shard_size,
            tag="flickr30k",
            bert=bert,
            dino=dino,
            tokenizer=tokenizer,
            processor=processor,
        )
    except Exception as e:
        print(f" Flickr30k failed: {e}")
        paths["flickr"] = None

    # Unload
    del bert, dino, tokenizer, processor
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Summary
    print(f"\n{'=' * 70}")
    print("CACHE SUMMARY")
    print(f"{'=' * 70}")
    for name, path in paths.items():
        if path and os.path.exists(path):
            ds = load_from_disk(path)
            print(f" {name:15s}: {len(ds):6d} samples [{path}]")

    print(f"\n Stage 2 usage:")
    print(f' ds = load_from_disk("{CFG.cache_dir}/coco_val").with_format("torch")')
    print(f' loader = DataLoader(ds, batch_size=64, num_workers=4)')
    print("\nDone.")