# geolip-bertenstein / cell1_prepare_data.py
# (Hugging Face page residue commented out: uploaded by AbstractPhil,
#  commit 6413528, verified — these lines were not valid Python.)
# ============================================================================
# GEOLIP-BERTENSTEIN STAGE 1: MULTI-EXPERT PRECOMPUTE (REFACTORED)
#
# BERT is the shared text spine.
#
# Pipeline per expert pair:
# 1. Load dataset / stream
# 2. CPU preprocess text + expert input
# 3. GPU encode text with BERT + expert with expert encoder
# 4. Shard-safe Arrow write
# 5. Merge shards -> final save_to_disk
# 6. Unload expert, keep BERT
#
# Experts:
# image : DINOv2-large + COCO-Caption
# audio : Whisper-large + LibriSpeech ASR (streaming)
# protein : ESM-2-650M + Protein2Text-QA (streaming)
# code : CodeBERT-base + CodeSearchNet python
# ============================================================================
import subprocess
import sys

# Environment bootstrap: verify that a *working* sympy is importable
# (some torch builds require it). A broken partial install can import
# but lack submodules, hence the `sympy.core` probe.
try:
    import sympy
    _ = sympy.core
except (ImportError, AttributeError):
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--upgrade", "sympy", "--break-system-packages", "-q"]
    )
    # Fix: the original installed sympy but never (re)imported it, leaving
    # this process without the module. Retry the import after the install.
    import importlib
    importlib.invalidate_caches()
    import sympy
    _ = sympy.core
import gc
import os
import shutil
import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import (
Audio,
Dataset as HFDataset,
Features,
Sequence,
Value,
Array2D,
concatenate_datasets,
load_dataset,
load_from_disk,
)
# ============================================================================
# BASE CONFIG
# ============================================================================
@dataclass
class BaseConfig:
    """Global settings for the Stage-1 multi-expert precompute run.

    A single instance (``CFG``) is created at module load and read
    everywhere; nothing mutates it afterwards.
    """
    # Root directory where all encoded Arrow caches are written.
    cache_dir: str = "/home/claude/geo_cache"
    # Fixed BERT token length for every caption/description.
    max_text_len: int = 32
    # Compute device: CUDA when available, else CPU.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # AMP toggle, true iff CUDA is present; not read in this file —
    # presumably consumed by later stages (confirm against Stage 2).
    amp_enabled: bool = torch.cuda.is_available()
    # Shared text spine: BERT-large identifier and its hidden width.
    bert_model_name: str = "google-bert/bert-large-uncased"
    bert_hidden_dim: int = 1024
    # DataLoader settings for the map-style (non-streaming) experts.
    batch_size: int = 256
    num_workers: int = 8
    prefetch_factor: int = 2
    pin_memory: bool = torch.cuda.is_available()
    # Rows buffered per Arrow shard before spilling to disk.
    shard_size_default: int = 2048
    # expert-specific max samples
    max_audio_samples: int = 10000
    max_protein_samples: int = 15000
    max_code_samples: int = 50000
    # Wipe ~/.cache/huggingface between experts to bound disk usage.
    cleanup_hf_cache_between_experts: bool = True


CFG = BaseConfig()  # module-wide config singleton
DEVICE = torch.device(CFG.device)
# ============================================================================
# HF CACHE CLEANUP
# ============================================================================
def cleanup_hf_cache() -> None:
    """Delete the HF datasets/hub caches to reclaim disk between experts.

    Reports the size of each cache subtree before removing it, then
    recreates the (now empty) directory so later downloads have a target.
    """
    root = os.path.expanduser("~/.cache/huggingface")
    for name in ("datasets", "hub"):
        target = os.path.join(root, name)
        if not os.path.exists(target):
            continue
        # Tally bytes first purely for the log line; unreadable entries
        # (races with other processes) are skipped.
        total_bytes = 0.0
        for dirpath, _, filenames in os.walk(target):
            for filename in filenames:
                try:
                    total_bytes += os.path.getsize(os.path.join(dirpath, filename))
                except OSError:
                    pass
        print(f" Cleaning {target} ({total_bytes / 1e9:.1f} GB)...")
        shutil.rmtree(target, ignore_errors=True)
        os.makedirs(target, exist_ok=True)
def cleanup_cuda() -> None:
    """Collect garbage, then drop the CUDA allocator cache if a GPU exists."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
# ============================================================================
# SHARED BERT
# ============================================================================
_bert_tokenizer = None  # lazily-created singleton shared by every expert runner


def get_bert_tokenizer():
    """Return the process-wide BERT tokenizer, constructing it on first use."""
    global _bert_tokenizer
    if _bert_tokenizer is not None:
        return _bert_tokenizer
    from transformers import BertTokenizer
    _bert_tokenizer = BertTokenizer.from_pretrained(CFG.bert_model_name)
    return _bert_tokenizer
def load_shared_bert():
    """Load the shared BERT-large text encoder onto DEVICE in eval mode.

    Uses fp16 weights on CUDA, fp32 on CPU. The returned model is the text
    spine reused across all expert passes.
    """
    from transformers import BertModel
    print("Loading shared BERT-large...")
    weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = BertModel.from_pretrained(CFG.bert_model_name, torch_dtype=weight_dtype)
    model = model.to(DEVICE).eval()
    print(" BERT ready.")
    return model
# ============================================================================
# COMMON HELPERS
# ============================================================================
def ensure_dir(path: str) -> None:
    """Idempotently create *path*, including any missing parent directories."""
    os.makedirs(path, exist_ok=True)
def make_loader(ds: Dataset, batch_size: int, num_workers: int) -> DataLoader:
    """Build a deterministic (unshuffled) DataLoader from the global config."""
    loader_args: Dict[str, Any] = {
        "dataset": ds,
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": num_workers,
        "pin_memory": CFG.pin_memory,
        "persistent_workers": num_workers > 0,
    }
    if num_workers > 0:
        # prefetch_factor is only legal when worker processes exist
        loader_args["prefetch_factor"] = CFG.prefetch_factor
    return DataLoader(**loader_args)
def masked_text_tokenize(text: str, tokenizer) -> Tuple[torch.Tensor, torch.Tensor]:
    """Tokenize *text* to the fixed global length.

    Returns 1-D (input_ids, attention_mask) tensors, padded/truncated to
    CFG.max_text_len.
    """
    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=CFG.max_text_len,
        return_tensors="pt",
    )
    return encoded["input_ids"].squeeze(0), encoded["attention_mask"].squeeze(0)
def extract_first_text(sample: Dict[str, Any], keys: List[str]) -> str:
    """Return the first non-empty caption found in *sample* under *keys*.

    Each candidate value may be a plain string, a list of strings, or a
    list of dicts carrying the text under "raw" (preferred) or "text";
    any other list head is stringified as a last resort. Returns "" when
    nothing usable is found.
    """
    for key in keys:
        value = sample.get(key)
        if isinstance(value, str):
            stripped = value.strip()
            if stripped:
                return stripped
        if isinstance(value, list) and value:
            head = value[0]
            if isinstance(head, str):
                head = head.strip()
                if head:
                    return head
            if isinstance(head, dict):
                candidate = str(head.get("raw", head.get("text", ""))).strip()
                if candidate:
                    return candidate
            candidate = str(head).strip()
            if candidate:
                return candidate
    return ""
# ============================================================================
# SHARD WRITER
# ============================================================================
class ShardWriter:
    """Buffered Arrow writer: accumulate rows, spill fixed-size shards, merge once.

    Rows are buffered in Python lists keyed by *row_keys*; every
    *shard_size* rows the buffer is written via ``save_to_disk`` under
    ``<cache_dir>/<tag>__shards``, and :meth:`finalize` concatenates all
    shards into the final dataset at ``<cache_dir>/<tag>``.
    """

    def __init__(
        self,
        cache_dir: str,
        tag: str,
        features: Features,
        shard_size: int,
        row_keys: List[str],
    ):
        self.cache_dir = cache_dir
        self.tag = tag
        self.features = features      # Arrow schema applied to every shard
        self.shard_size = shard_size  # rows buffered before a spill
        self.row_keys = row_keys
        self.cache_path = os.path.join(cache_dir, tag)
        self.shard_root = os.path.join(cache_dir, f"{tag}__shards")
        self.rows = {k: [] for k in row_keys}
        self.shard_paths: List[str] = []
        self.shard_idx = 0
        self.n_written = 0

    @property
    def exists(self) -> bool:
        """True when the final merged cache is already on disk."""
        return os.path.exists(self.cache_path)

    def add_row(self, row: Dict[str, Any]) -> None:
        """Buffer one row; spill a shard when the buffer reaches shard_size."""
        for k in self.row_keys:
            self.rows[k].append(row[k])
        if len(self.rows[self.row_keys[0]]) >= self.shard_size:
            self.flush()

    def flush(self) -> None:
        """Write the current buffer to a new shard directory (no-op if empty)."""
        n_rows = len(self.rows[self.row_keys[0]])
        if n_rows == 0:
            return
        ensure_dir(self.shard_root)
        shard_path = os.path.join(self.shard_root, f"shard_{self.shard_idx:05d}")
        ds = HFDataset.from_dict(self.rows, features=self.features)
        ds.save_to_disk(shard_path)
        self.shard_paths.append(shard_path)
        self.shard_idx += 1
        self.n_written += n_rows
        self.rows = {k: [] for k in self.row_keys}

    def finalize(self) -> str:
        """Flush, merge all shards into the final cache, and clean up.

        Returns the final cache path. Fix: when no rows were ever written,
        an empty dataset with the declared schema is saved instead of
        crashing inside ``concatenate_datasets([])`` (which raises on an
        empty list).
        """
        self.flush()
        if not self.shard_paths:
            empty = HFDataset.from_dict({k: [] for k in self.row_keys}, features=self.features)
            empty.save_to_disk(self.cache_path)
            print(f" Saved 0 pairs to {self.cache_path}")
            return self.cache_path
        print(f" Merging {len(self.shard_paths)} shards...")
        merged = concatenate_datasets([load_from_disk(p) for p in self.shard_paths])
        merged.save_to_disk(self.cache_path)
        print(f" Saved {len(merged)} pairs to {self.cache_path}")
        if os.path.exists(self.shard_root):
            shutil.rmtree(self.shard_root, ignore_errors=True)
        return self.cache_path
# ============================================================================
# MAP-STYLE DATASETS (NON-STREAMING)
# ============================================================================
class ImageTextDataset(Dataset):
    """Map-style dataset pairing COCO captions with DINOv2 pixel inputs.

    ``__getitem__`` returns (input_ids, attention_mask, pixel_values, valid).
    Samples whose image is missing or fails preprocessing yield a zero
    pixel tensor and valid=False so batch shapes stay fixed.
    """

    def __init__(self, hf_ds, bert_tokenizer, image_processor):
        self.ds = hf_ds
        self.tok = bert_tokenizer
        self.proc = image_processor
        # shape of the zero tensor emitted for unusable samples —
        # presumably matches the processor's output resolution (confirm)
        self.fallback_shape = (3, 518, 518)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        sample = self.ds[idx]
        caption = extract_first_text(
            sample,
            ["answer", "caption", "captions", "text", "original_alt_text"],
        )
        ids, mask = masked_text_tokenize(caption, self.tok)
        image = sample.get("image", None)
        pixels = None
        if image is not None and hasattr(image, "convert"):
            try:
                pixels = self.proc(
                    images=image.convert("RGB"),
                    return_tensors="pt",
                )["pixel_values"].squeeze(0)
            except Exception:
                pixels = None
        if pixels is None:
            return ids, mask, torch.zeros(self.fallback_shape, dtype=torch.float32), False
        return ids, mask, pixels, True
class CodeTextDataset(Dataset):
    """Pairs CodeSearchNet docstrings (BERT side) with CodeBERT-tokenized code.

    ``__getitem__`` returns (input_ids, attention_mask, packed, valid) where
    ``packed`` stacks [code_ids, code_mask] into a (2, 256) long tensor so a
    default collate can batch it.
    """

    def __init__(self, hf_ds, bert_tokenizer, code_tokenizer):
        self.ds = hf_ds
        self.tok = bert_tokenizer
        self.code_tok = code_tokenizer

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        sample = self.ds[idx]
        doc = sample.get("func_documentation_string", "")
        if not doc or not doc.strip():
            # no docstring: fall back to a prefix of the function source
            doc = str(sample.get("whole_func_string", ""))[:200]
        doc = str(doc).strip()[:500]
        ids, mask = masked_text_tokenize(doc, self.tok)
        code = str(sample.get("func_code_string", sample.get("whole_func_string", ""))).strip()[:512]
        valid = len(code) > 5 and len(doc) > 5
        code_ids = code_mask = None
        if valid:
            try:
                encoded = self.code_tok(
                    code,
                    padding="max_length",
                    truncation=True,
                    max_length=256,
                    return_tensors="pt",
                )
                code_ids = encoded["input_ids"].squeeze(0)
                code_mask = encoded["attention_mask"].squeeze(0)
            except Exception:
                code_ids = code_mask = None
                valid = False
        if code_ids is None:
            code_ids = torch.zeros(256, dtype=torch.long)
            code_mask = torch.zeros(256, dtype=torch.long)
        return ids, mask, torch.stack([code_ids, code_mask]), valid
# ============================================================================
# SHARED NON-STREAM ENCODER
# ============================================================================
@torch.no_grad()
def encode_map_dataset(
    *,
    tag: str,
    loader: DataLoader,
    bert,
    expert_name: str,
    expert_hidden_shape: Tuple[int, int],
    expert_forward: Callable[[torch.Tensor], torch.Tensor],
    shard_size: int,
    max_samples: Optional[int] = None,
) -> str:
    """Encode a map-style loader through BERT + one expert and cache to disk.

    Each batch is (text_ids, text_mask, expert_input, valid); invalid rows
    are dropped before either encoder runs. Encoded rows are spilled to
    Arrow shards of *shard_size* and merged into one dataset at the end.
    Returns the path of the (possibly pre-existing) merged cache.
    """
    cache_path = os.path.join(CFG.cache_dir, tag)
    if os.path.exists(cache_path):
        # resumability: a finished cache short-circuits the whole pass
        print(f" Cache exists: {cache_path}")
        return cache_path
    features = Features({
        "text_hidden": Array2D(shape=(CFG.max_text_len, CFG.bert_hidden_dim), dtype="float16"),
        "text_mask": Sequence(Value("bool"), length=CFG.max_text_len),
        f"{expert_name}_hidden": Array2D(shape=expert_hidden_shape, dtype="float16"),
    })
    writer = ShardWriter(
        cache_dir=CFG.cache_dir,
        tag=tag,
        features=features,
        shard_size=shard_size,
        row_keys=["text_hidden", "text_mask", f"{expert_name}_hidden"],
    )
    t0 = time.time()
    n = 0
    for batch in loader:
        text_ids, text_mask, expert_input, valid = batch
        valid_b = valid.bool()
        if not valid_b.any():
            continue  # whole batch unusable
        # mask out invalid rows, then move only the survivors to the GPU
        text_ids = text_ids[valid_b].to(DEVICE, non_blocking=True)
        text_mask_gpu = text_mask[valid_b].to(DEVICE, non_blocking=True)
        expert_input = expert_input[valid_b].to(DEVICE, non_blocking=True)
        # fp16 on disk halves cache size; .cpu().numpy() leaves the GPU
        text_hidden = bert(
            input_ids=text_ids,
            attention_mask=text_mask_gpu,
        ).last_hidden_state.detach().to(dtype=torch.float16).cpu().numpy()
        text_mask_np = text_mask_gpu.bool().cpu().numpy()
        expert_hidden = expert_forward(expert_input).detach().to(dtype=torch.float16).cpu().numpy()
        for i in range(text_hidden.shape[0]):
            writer.add_row({
                "text_hidden": text_hidden[i],
                "text_mask": text_mask_np[i].tolist(),
                f"{expert_name}_hidden": expert_hidden[i],
            })
        n += text_hidden.shape[0]
        # progress line roughly every 1000 samples (batch-granular check)
        if n % 1000 < CFG.batch_size or n <= CFG.batch_size:
            rate = n / max(time.time() - t0, 1e-6)
            print(f" {n}" + (f"/{max_samples}" if max_samples else "") + f" ({rate:.0f}/s)")
        if max_samples is not None and n >= max_samples:
            break
    final_path = writer.finalize()
    print(f" Completed {n} samples in {time.time() - t0:.0f}s")
    return final_path
# ============================================================================
# STREAMING HELPERS
# ============================================================================
def decode_audio_obj(audio_obj) -> Tuple[np.ndarray, int]:
    """Normalize a HF audio column entry to (waveform, sample_rate).

    Supports both AudioDecoder-style objects exposing ``get_all_samples``
    and legacy ``{"array", "sampling_rate"}`` dicts (defaulting to 16 kHz
    when the rate is absent). Raises TypeError for anything else.
    """
    if hasattr(audio_obj, "get_all_samples"):
        decoded = audio_obj.get_all_samples()
        return decoded.data.numpy().squeeze(), decoded.sample_rate
    if isinstance(audio_obj, dict):
        return audio_obj["array"], audio_obj.get("sampling_rate", 16000)
    raise TypeError(f"Unsupported audio object type: {type(audio_obj)}")
def stream_librispeech_batches(
    stream,
    bert_tokenizer,
    whisper_processor,
    batch_size: int,
) -> Iterable[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Yield (text_ids, text_masks, log-mel features) batches from a stream.

    Samples with an empty transcript, missing audio, or a decode/feature
    failure are skipped. The final partial batch is flushed at stream end.
    """
    pending: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = []
    for sample in stream:
        transcript = sample.get("text", sample.get("transcription", ""))
        audio_obj = sample.get("audio")
        if not transcript or audio_obj is None:
            continue
        try:
            waveform, rate = decode_audio_obj(audio_obj)
        except Exception:
            continue
        ids, mask = masked_text_tokenize(str(transcript), bert_tokenizer)
        try:
            mel = whisper_processor(
                waveform,
                sampling_rate=rate,
                return_tensors="pt",
            ).input_features.squeeze(0)
        except Exception:
            continue
        pending.append((ids, mask, mel))
        if len(pending) >= batch_size:
            yield tuple(torch.stack(column) for column in zip(*pending))
            pending = []
    if pending:
        yield tuple(torch.stack(column) for column in zip(*pending))
def extract_protein_caption(sample: Dict[str, Any]) -> str:
    """Pull the answer text out of a Protein2Text-QA conversation record.

    Preference order: first non-empty "gpt" turn, then any turn carrying a
    "value", then the raw "protein" field. Output is capped at 500 chars.
    """
    turns = sample.get("conversations", [])
    if isinstance(turns, list):
        for turn in turns:
            if isinstance(turn, dict) and turn.get("from") == "gpt":
                text = str(turn.get("value", "")).strip()
                if text:
                    return text[:500]
        for turn in turns:
            if isinstance(turn, dict) and "value" in turn:
                text = str(turn["value"]).strip()
                if text:
                    return text[:500]
    return str(sample.get("protein", "")).strip()[:500]
def stream_protein_batches(
    stream,
    bert_tokenizer,
    esm_tokenizer,
    batch_size: int,
) -> Iterable[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Yield (text_ids, text_masks, esm_ids, esm_masks) batches from a stream.

    Pairs with a caption or amino-acid sequence shorter than 5 characters,
    or a failing ESM tokenization, are skipped. The final partial batch is
    flushed at stream end.
    """
    pending: List[Tuple[torch.Tensor, ...]] = []
    for sample in stream:
        caption = extract_protein_caption(sample)
        sequence = str(sample.get("amino_seq", sample.get("protein_sequence", ""))).strip()
        if len(caption) < 5 or len(sequence) < 5:
            continue
        ids, mask = masked_text_tokenize(caption, bert_tokenizer)
        try:
            esm_enc = esm_tokenizer(
                sequence,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
        except Exception:
            continue
        pending.append((
            ids,
            mask,
            esm_enc["input_ids"].squeeze(0),
            esm_enc["attention_mask"].squeeze(0),
        ))
        if len(pending) >= batch_size:
            yield tuple(torch.stack(column) for column in zip(*pending))
            pending = []
    if pending:
        yield tuple(torch.stack(column) for column in zip(*pending))
@torch.no_grad()
def encode_streaming_batches(
    *,
    tag: str,
    expert_name: str,
    expert_hidden_shape: Tuple[int, int],
    batch_iter: Iterable,
    bert,
    expert_batch_forward: Callable[..., torch.Tensor],
    shard_size: int,
    row_keys: List[str],
    max_samples: Optional[int] = None,
) -> str:
    """Encode pre-batched streaming tuples through BERT + one expert.

    *batch_iter* yields tuples whose first two tensors are BERT input ids
    and attention masks; everything after index 1 is forwarded verbatim to
    *expert_batch_forward*. Rows are spilled to Arrow shards and merged at
    the end. Returns the (possibly pre-existing) merged cache path.
    """
    cache_path = os.path.join(CFG.cache_dir, tag)
    if os.path.exists(cache_path):
        # resumability: a finished cache short-circuits the whole pass
        print(f" Cache exists: {cache_path}")
        return cache_path
    features = Features({
        "text_hidden": Array2D(shape=(CFG.max_text_len, CFG.bert_hidden_dim), dtype="float16"),
        "text_mask": Sequence(Value("bool"), length=CFG.max_text_len),
        f"{expert_name}_hidden": Array2D(shape=expert_hidden_shape, dtype="float16"),
    })
    writer = ShardWriter(
        cache_dir=CFG.cache_dir,
        tag=tag,
        features=features,
        shard_size=shard_size,
        row_keys=row_keys,
    )
    t0 = time.time()
    n = 0
    for packed in batch_iter:
        # first two are always bert ids/masks
        text_ids = packed[0].to(DEVICE, non_blocking=True)
        text_mask = packed[1].to(DEVICE, non_blocking=True)
        # fp16 on disk halves cache size; .cpu().numpy() leaves the GPU
        text_hidden = bert(
            input_ids=text_ids,
            attention_mask=text_mask,
        ).last_hidden_state.detach().to(dtype=torch.float16).cpu().numpy()
        text_mask_np = text_mask.bool().cpu().numpy()
        # remaining tensors are the expert's inputs, moved to DEVICE here
        expert_hidden = expert_batch_forward(*[p.to(DEVICE, non_blocking=True) for p in packed[2:]])
        expert_hidden = expert_hidden.detach().to(dtype=torch.float16).cpu().numpy()
        for i in range(text_hidden.shape[0]):
            writer.add_row({
                "text_hidden": text_hidden[i],
                "text_mask": text_mask_np[i].tolist(),
                f"{expert_name}_hidden": expert_hidden[i],
            })
        n += text_hidden.shape[0]
        batch_size = text_hidden.shape[0]
        # progress line roughly every 1000 samples (batch-granular check)
        if n % 1000 < batch_size or n <= batch_size:
            rate = n / max(time.time() - t0, 1e-6)
            print(f" {n}" + (f"/{max_samples}" if max_samples else "") + f" ({rate:.0f}/s)")
        if max_samples is not None and n >= max_samples:
            break
    final_path = writer.finalize()
    print(f" Completed {n} samples in {time.time() - t0:.0f}s")
    return final_path
# ============================================================================
# EXPERT RUNNERS
# ============================================================================
def encode_image_expert(bert, split: str, tag: str, max_samples: Optional[int] = None) -> str:
    """Encode a COCO-Caption *split* with DINOv2-large and cache under *tag*.

    Loads the vision model, builds a map-style caption/pixel dataset, runs
    the shared encode loop, then unloads the expert and frees CUDA memory.
    """
    from transformers import Dinov2Model, AutoImageProcessor
    print(f"\n [IMAGE] Loading DINOv2-large + COCO-Caption ({split})...")
    vision_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    dino = Dinov2Model.from_pretrained(
        "facebook/dinov2-large",
        torch_dtype=vision_dtype,
    ).to(DEVICE).eval()
    image_proc = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
    bert_tok = get_bert_tokenizer()
    hf_ds = load_dataset("lmms-lab/COCO-Caption", split=split)
    if max_samples is not None:
        hf_ds = hf_ds.select(range(min(max_samples, len(hf_ds))))
    print(f" Dataset: {len(hf_ds)} samples")
    torch_ds = ImageTextDataset(hf_ds, bert_tok, image_proc)
    loader = make_loader(torch_ds, batch_size=CFG.batch_size, num_workers=CFG.num_workers)

    def run_dino(pixel_values):
        return dino(pixel_values=pixel_values).last_hidden_state

    path = encode_map_dataset(
        tag=tag,
        loader=loader,
        bert=bert,
        expert_name="image",
        expert_hidden_shape=(257, 1024),
        expert_forward=run_dino,
        shard_size=CFG.shard_size_default,
        max_samples=max_samples,
    )
    del dino, image_proc, hf_ds, torch_ds, loader
    cleanup_cuda()
    return path
def encode_code_expert(bert, max_samples: Optional[int] = None) -> str:
    """Encode CodeSearchNet-python docstring/code pairs with CodeBERT.

    Filters out rows without a docstring (the text side of each pair),
    runs the shared encode loop, then unloads the expert.
    """
    from transformers import RobertaModel, RobertaTokenizer
    print("\n [CODE] Loading CodeBERT + CodeSearchNet python...")
    code_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    codebert = RobertaModel.from_pretrained(
        "microsoft/codebert-base",
        torch_dtype=code_dtype,
    ).to(DEVICE).eval()
    code_tok = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
    bert_tok = get_bert_tokenizer()
    hf_ds = load_dataset("code-search-net/code_search_net", "python", split="train")
    if max_samples is not None:
        hf_ds = hf_ds.select(range(min(max_samples, len(hf_ds))))
    hf_ds = hf_ds.filter(
        lambda x: bool(x.get("func_documentation_string", "").strip()),
        num_proc=4,
    )
    print(f" Dataset: {len(hf_ds)} samples (after filtering)")
    torch_ds = CodeTextDataset(hf_ds, bert_tok, code_tok)
    loader = make_loader(torch_ds, batch_size=CFG.batch_size, num_workers=CFG.num_workers)

    def run_codebert(packed):
        # packed carries stacked [code_ids, code_mask] per sample
        ids = packed[:, 0].long()
        attn = packed[:, 1].long()
        return codebert(input_ids=ids, attention_mask=attn).last_hidden_state

    path = encode_map_dataset(
        tag="code_csn",
        loader=loader,
        bert=bert,
        expert_name="code",
        expert_hidden_shape=(256, 768),
        expert_forward=run_codebert,
        shard_size=CFG.shard_size_default,
        max_samples=max_samples,
    )
    del codebert, code_tok, hf_ds, torch_ds, loader
    cleanup_cuda()
    return path
def encode_audio_expert(bert, max_samples: Optional[int] = None) -> str:
    """Encode LibriSpeech (streaming) with the Whisper-large-v3 encoder.

    Probes one utterance first to discover the encoder's output
    (seq_len, hidden_dim) rather than hard-coding the Arrow schema, then
    re-opens the stream and encodes batches of 16. Returns the merged
    cache path.
    """
    from transformers import WhisperModel, WhisperFeatureExtractor
    print("\n [AUDIO] Loading Whisper-large-v3 + LibriSpeech ASR (streaming)...")
    whisper = WhisperModel.from_pretrained(
        "openai/whisper-large-v3",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).to(DEVICE).eval()
    proc = WhisperFeatureExtractor.from_pretrained("openai/whisper-large-v3")
    tok = get_bert_tokenizer()
    max_n = max_samples or CFG.max_audio_samples
    audio_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    # probe shape: run one sample through the encoder
    probe_stream = load_dataset("openslr/librispeech_asr", "clean", split="train.100", streaming=True)
    probe_stream = probe_stream.cast_column("audio", Audio(sampling_rate=16000))
    first = next(iter(probe_stream))
    arr, sr = decode_audio_obj(first["audio"])
    mel = proc(arr, sampling_rate=sr, return_tensors="pt").input_features
    mel = mel.to(device=DEVICE, dtype=audio_dtype)
    with torch.no_grad():
        probe_hidden = whisper.encoder(mel).last_hidden_state
    seq_len, hidden_dim = probe_hidden.shape[1], probe_hidden.shape[2]
    print(f" Whisper encoder output: ({seq_len}, {hidden_dim})")
    del mel, probe_hidden
    # fresh stream for the real pass (the probe consumed the first element)
    stream = load_dataset("openslr/librispeech_asr", "clean", split="train.100", streaming=True)
    stream = stream.cast_column("audio", Audio(sampling_rate=16000))
    batch_iter = stream_librispeech_batches(
        stream=stream,
        bert_tokenizer=tok,
        whisper_processor=proc,
        batch_size=16,
    )

    def expert_batch_forward(mels: torch.Tensor) -> torch.Tensor:
        # match the model's weight dtype before the encoder forward
        mels = mels.to(dtype=audio_dtype)
        return whisper.encoder(mels).last_hidden_state

    path = encode_streaming_batches(
        tag="audio_librispeech",
        expert_name="audio",
        expert_hidden_shape=(seq_len, hidden_dim),
        batch_iter=batch_iter,
        bert=bert,
        expert_batch_forward=expert_batch_forward,
        shard_size=256,  # large hidden size; keep shards smaller
        row_keys=["text_hidden", "text_mask", "audio_hidden"],
        max_samples=max_n,
    )
    del whisper, proc
    cleanup_cuda()
    return path
def encode_protein_expert(bert, max_samples: Optional[int] = None) -> str:
    """Encode Protein2Text-QA answer/sequence pairs with ESM-2-650M (streaming).

    Streams the dataset, batches via stream_protein_batches, runs the
    shared streaming encode loop, then unloads the expert.
    """
    from transformers import EsmModel, EsmTokenizer
    print("\n [PROTEIN] Loading ESM-2-650M + Protein2Text-QA (streaming)...")
    esm_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    esm = EsmModel.from_pretrained(
        "facebook/esm2_t33_650M_UR50D",
        torch_dtype=esm_dtype,
    ).to(DEVICE).eval()
    esm_tok = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    bert_tok = get_bert_tokenizer()
    sample_cap = max_samples or CFG.max_protein_samples
    stream = load_dataset("tumorailab/Protein2Text-QA", split="test", streaming=True)
    batch_iter = stream_protein_batches(
        stream=stream,
        bert_tokenizer=bert_tok,
        esm_tokenizer=esm_tok,
        batch_size=32,
    )

    def run_esm(esm_ids: torch.Tensor, esm_mask: torch.Tensor) -> torch.Tensor:
        return esm(input_ids=esm_ids.long(), attention_mask=esm_mask.long()).last_hidden_state

    path = encode_streaming_batches(
        tag="protein_p2t",
        expert_name="protein",
        expert_hidden_shape=(512, 1280),
        batch_iter=batch_iter,
        bert=bert,
        expert_batch_forward=run_esm,
        shard_size=512,
        row_keys=["text_hidden", "text_mask", "protein_hidden"],
        max_samples=sample_cap,
    )
    del esm, esm_tok
    cleanup_cuda()
    return path
# ============================================================================
# MAIN
# ============================================================================
def main():
    """Orchestrate the Stage-1 precompute across all four experts.

    Each expert pass is skipped when its cache directory already exists.
    The audio/protein/code runs are wrapped in try/except so one failing
    source does not abort the rest; the image runs are not. BERT is only
    loaded when at least one cache is missing. Ends with a summary of
    every cache found on disk.
    """
    ensure_dir(CFG.cache_dir)
    print("=" * 70)
    print("STAGE 1: MULTI-EXPERT PRECOMPUTE")
    print("=" * 70)
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"Cache: {CFG.cache_dir}")
    required_tags = [
        "image_coco",
        "image_coco_test",
        "audio_librispeech",
        "protein_p2t",
        "code_csn",
    ]
    missing = [t for t in required_tags if not os.path.exists(os.path.join(CFG.cache_dir, t))]
    if not missing:
        # everything cached: skip loading BERT entirely
        print("\nAll caches exist. Nothing to encode.")
        bert = None
    else:
        print(f"\nMissing caches: {missing}")
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()
        bert = load_shared_bert()
    paths: Dict[str, Optional[str]] = {}
    # IMAGE TRAIN (COCO "val" split serves as the training pairs)
    print(f"\n{'─' * 50}")
    print("[1/4] IMAGE — DINOv2 + COCO-Caption")
    if os.path.exists(os.path.join(CFG.cache_dir, "image_coco")):
        print(" [IMAGE] Cache exists, skipping.")
        paths["image"] = os.path.join(CFG.cache_dir, "image_coco")
    else:
        paths["image"] = encode_image_expert(bert, split="val", tag="image_coco")
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()
    # IMAGE TEST
    if os.path.exists(os.path.join(CFG.cache_dir, "image_coco_test")):
        print("\n [IMAGE-TEST] Cache exists, skipping.")
        paths["image_test"] = os.path.join(CFG.cache_dir, "image_coco_test")
    else:
        print("\n [IMAGE-TEST] COCO test split...")
        paths["image_test"] = encode_image_expert(bert, split="test", tag="image_coco_test")
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()
    # AUDIO (streaming; failure tolerated)
    print(f"\n{'─' * 50}")
    print("[2/4] AUDIO — Whisper + LibriSpeech")
    if os.path.exists(os.path.join(CFG.cache_dir, "audio_librispeech")):
        print(" [AUDIO] Cache exists, skipping.")
        paths["audio"] = os.path.join(CFG.cache_dir, "audio_librispeech")
    else:
        try:
            paths["audio"] = encode_audio_expert(bert, max_samples=CFG.max_audio_samples)
        except Exception as e:
            print(f" AUDIO failed: {e}")
            paths["audio"] = None
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()
    # PROTEIN (streaming; failure tolerated)
    print(f"\n{'─' * 50}")
    print("[3/4] PROTEIN — ESM-2 + Protein2Text-QA")
    if os.path.exists(os.path.join(CFG.cache_dir, "protein_p2t")):
        print(" [PROTEIN] Cache exists, skipping.")
        paths["protein"] = os.path.join(CFG.cache_dir, "protein_p2t")
    else:
        try:
            paths["protein"] = encode_protein_expert(bert, max_samples=CFG.max_protein_samples)
        except Exception as e:
            print(f" PROTEIN failed: {e}")
            paths["protein"] = None
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()
    # CODE (failure tolerated)
    print(f"\n{'─' * 50}")
    print("[4/4] CODE — CodeBERT + CodeSearchNet Python")
    if os.path.exists(os.path.join(CFG.cache_dir, "code_csn")):
        print(" [CODE] Cache exists, skipping.")
        paths["code"] = os.path.join(CFG.cache_dir, "code_csn")
    else:
        try:
            paths["code"] = encode_code_expert(bert, max_samples=CFG.max_code_samples)
        except Exception as e:
            print(f" CODE failed: {e}")
            paths["code"] = None
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()
    # release the shared text spine before the summary pass
    if bert is not None:
        del bert
    cleanup_cuda()
    # pick up an externally-produced flickr cache if one is present
    flickr_path = os.path.join(CFG.cache_dir, "flickr30k")
    if os.path.exists(flickr_path):
        paths["flickr"] = flickr_path
    print(f"\n{'=' * 70}")
    print("CACHE SUMMARY")
    print(f"{'=' * 70}")
    for name, path in paths.items():
        if path and os.path.exists(path):
            ds = load_from_disk(path)
            print(f" {name:15s}: {len(ds):6d} pairs [{path}]")
    print("\nReady for Stage 2 multi-expert training.")
    print("Done.")


if __name__ == "__main__":
    main()