# ============================================================================
# GEOLIP-BERTENSTEIN STAGE 1: MULTI-EXPERT PRECOMPUTE (REFACTORED)
#
# BERT is the shared text spine.
#
# Pipeline per expert pair:
#   1. Load dataset / stream
#   2. CPU preprocess text + expert input
#   3. GPU encode text with BERT + expert with expert encoder
#   4. Shard-safe Arrow write
#   5. Merge shards -> final save_to_disk
#   6. Unload expert, keep BERT
#
# Experts:
#   image   : DINOv2-large  + COCO-Caption
#   audio   : Whisper-large + LibriSpeech ASR (streaming)
#   protein : ESM-2-650M    + Protein2Text-QA (streaming)
#   code    : CodeBERT-base + CodeSearchNet python
# ============================================================================

import subprocess
import sys

# Bootstrap: if sympy is missing or lacks `sympy.core`, upgrade it with pip
# before the heavy imports below run.
# NOTE(review): presumably needed because torch pulls in sympy -- confirm.
try:
    import sympy
    _ = sympy.core
except (ImportError, AttributeError):
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--upgrade", "sympy", "--break-system-packages", "-q"]
    )

import gc
import os
import shutil
import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from datasets import (
    Audio,
    Dataset as HFDataset,
    Features,
    Sequence,
    Value,
    Array2D,
    concatenate_datasets,
    load_dataset,
    load_from_disk,
)


# ============================================================================
# BASE CONFIG
# ============================================================================

@dataclass
class BaseConfig:
    """Run-wide settings shared by every expert encoder in this script."""

    cache_dir: str = "/home/claude/geo_cache"
    max_text_len: int = 32  # BERT token length every text is padded/truncated to
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    amp_enabled: bool = torch.cuda.is_available()

    bert_model_name: str = "google-bert/bert-large-uncased"
    bert_hidden_dim: int = 1024

    # DataLoader settings for the map-style (non-streaming) experts
    batch_size: int = 256
    num_workers: int = 8
    prefetch_factor: int = 2
    pin_memory: bool = torch.cuda.is_available()

    shard_size_default: int = 2048

    # expert-specific max samples
    max_audio_samples: int = 10000
    max_protein_samples: int = 15000
    max_code_samples: int = 50000

    cleanup_hf_cache_between_experts: bool = True


CFG = BaseConfig()
DEVICE = torch.device(CFG.device)


# ============================================================================
# HF CACHE CLEANUP
# ============================================================================

def cleanup_hf_cache() -> None:
    """Delete HF datasets/hub cache to free disk after encoding an expert.

    Each of ``~/.cache/huggingface/{datasets,hub}`` that exists is measured,
    removed, and recreated empty.
    """
    hf_cache = os.path.expanduser("~/.cache/huggingface")
    for subdir in ["datasets", "hub"]:
        target = os.path.join(hf_cache, subdir)
        if not os.path.exists(target):
            continue

        size_bytes = 0.0
        for dirpath, _dirnames, filenames in os.walk(target):
            for filename in filenames:
                try:
                    size_bytes += os.path.getsize(os.path.join(dirpath, filename))
                except OSError:
                    # Best effort: files that cannot be stat'ed are left out
                    # of the reported size.
                    pass

        print(f" Cleaning {target} ({size_bytes / 1e9:.1f} GB)...")
        shutil.rmtree(target, ignore_errors=True)
        os.makedirs(target, exist_ok=True)


def cleanup_cuda() -> None:
    """Run the garbage collector and, when CUDA is available, empty its cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# ============================================================================
# SHARED BERT
# ============================================================================

_bert_tokenizer = None


def get_bert_tokenizer():
    """Return the BERT tokenizer, loading it on first use and reusing it after."""
    global _bert_tokenizer
    if _bert_tokenizer is None:
        from transformers import BertTokenizer

        _bert_tokenizer = BertTokenizer.from_pretrained(CFG.bert_model_name)
    return _bert_tokenizer


def load_shared_bert():
    """Load ``CFG.bert_model_name`` onto ``DEVICE`` in eval mode.

    Weights are requested in float16 when CUDA is available, float32 otherwise.
    """
    from transformers import BertModel

    print("Loading shared BERT-large...")
    bert = BertModel.from_pretrained(
        CFG.bert_model_name,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).to(DEVICE).eval()
    print(" BERT ready.")
    return bert


# ============================================================================
# COMMON HELPERS
# ============================================================================

def ensure_dir(path: str) -> None:
    """Create ``path`` (and parents) if it does not exist yet."""
    os.makedirs(path, exist_ok=True)


def make_loader(ds: Dataset, batch_size: int, num_workers: int) -> DataLoader:
    """Build an unshuffled DataLoader over ``ds`` using the shared config."""
    kwargs = dict(
        dataset=ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=CFG.pin_memory,
        persistent_workers=num_workers > 0,
    )
    # prefetch_factor is only passed when worker processes are used.
    if num_workers > 0:
        kwargs["prefetch_factor"] = CFG.prefetch_factor
    return DataLoader(**kwargs)


def masked_text_tokenize(text: str, tokenizer) -> Tuple[torch.Tensor, torch.Tensor]:
    """Tokenize ``text`` to exactly ``CFG.max_text_len`` tokens.

    Returns:
        ``(input_ids, attention_mask)`` with the leading batch axis squeezed out.
    """
    tok = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=CFG.max_text_len,
        return_tensors="pt",
    )
    return tok["input_ids"].squeeze(0), tok["attention_mask"].squeeze(0)


def extract_first_text(sample: Dict[str, Any], keys: List[str]) -> str:
    """Return the first non-empty text found under ``keys``, or ``""``.

    Keys are tried in order. A string value is stripped and returned if
    non-empty. For a non-empty list only its first element is considered:
    a string is stripped; a dict is read through ``"raw"`` then ``"text"``;
    anything else (or a dict without usable text) falls back to ``str()``.
    """
    for key in keys:
        if key not in sample:
            continue
        value = sample[key]

        if isinstance(value, str):
            value = value.strip()
            if value:
                return value

        if isinstance(value, list) and value:
            first = value[0]
            if isinstance(first, str):
                first = first.strip()
                if first:
                    return first
            if isinstance(first, dict):
                txt = str(first.get("raw", first.get("text", ""))).strip()
                if txt:
                    return txt
            txt = str(first).strip()
            if txt:
                return txt
    return ""


# ============================================================================
# SHARD WRITER
# ============================================================================

class ShardWriter:
    """Buffer rows in memory, spill them to on-disk shards, then merge.

    Rows accumulate column-wise in ``self.rows``. Once ``shard_size`` rows are
    buffered they are saved under ``<cache_dir>/<tag>__shards/shard_NNNNN``.
    ``finalize`` concatenates all shards into ``<cache_dir>/<tag>`` and
    deletes the shard directory.
    """

    def __init__(
        self,
        cache_dir: str,
        tag: str,
        features: Features,
        shard_size: int,
        row_keys: List[str],
    ):
        self.cache_dir = cache_dir
        self.tag = tag
        self.features = features
        self.shard_size = shard_size
        self.row_keys = row_keys

        self.cache_path = os.path.join(cache_dir, tag)
        self.shard_root = os.path.join(cache_dir, f"{tag}__shards")

        self.rows = {k: [] for k in row_keys}
        self.shard_paths: List[str] = []
        self.shard_idx = 0
        self.n_written = 0

    @property
    def exists(self) -> bool:
        """Whether the final merged dataset directory is already on disk."""
        return os.path.exists(self.cache_path)

    def add_row(self, row: Dict[str, Any]) -> None:
        """Append one row (a value per ``row_keys`` entry); flush when full."""
        for k in self.row_keys:
            self.rows[k].append(row[k])
        if len(self.rows[self.row_keys[0]]) >= self.shard_size:
            self.flush()

    def flush(self) -> None:
        """Write the buffered rows as the next shard; no-op when empty."""
        n_rows = len(self.rows[self.row_keys[0]])
        if n_rows == 0:
            return

        ensure_dir(self.shard_root)
        shard_path = os.path.join(self.shard_root, f"shard_{self.shard_idx:05d}")
        ds = HFDataset.from_dict(self.rows, features=self.features)
        ds.save_to_disk(shard_path)

        self.shard_paths.append(shard_path)
        self.shard_idx += 1
        self.n_written += n_rows
        self.rows = {k: [] for k in self.row_keys}

    def finalize(self) -> str:
        """Flush, merge every shard into ``cache_path``, and return that path.

        Raises:
            ValueError: if no row was ever written. Failing here with an
                explicit message replaces the opaque error that merging an
                empty shard list would raise, and avoids leaving an empty
                cache directory that later runs would treat as complete.
        """
        self.flush()
        if not self.shard_paths:
            shutil.rmtree(self.shard_root, ignore_errors=True)
            raise ValueError(
                f"No rows were written for '{self.tag}'; nothing to merge into {self.cache_path}"
            )

        print(f" Merging {len(self.shard_paths)} shards...")
        merged = concatenate_datasets([load_from_disk(p) for p in self.shard_paths])
        merged.save_to_disk(self.cache_path)
        print(f" Saved {len(merged)} pairs to {self.cache_path}")

        if os.path.exists(self.shard_root):
            shutil.rmtree(self.shard_root, ignore_errors=True)
        return self.cache_path


# ============================================================================
# MAP-STYLE DATASETS (NON-STREAMING)
# ============================================================================

class ImageTextDataset(Dataset):
    """Pairs a tokenized caption with processor-ready pixel values.

    Each item is ``(input_ids, attention_mask, pixel_values, valid)``. When
    the image is missing or cannot be processed, ``pixel_values`` is a zero
    tensor of ``fallback_shape`` and ``valid`` is False.
    """

    # Used only when the processor's output shape cannot be probed.
    _DEFAULT_FALLBACK_SHAPE = (3, 518, 518)

    def __init__(self, hf_ds, bert_tokenizer, image_processor):
        self.ds = hf_ds
        self.tok = bert_tokenizer
        self.proc = image_processor
        # The placeholder must have the same shape as real processor output,
        # otherwise a batch mixing valid and invalid samples cannot be
        # stacked by the DataLoader.
        self.fallback_shape = self._probe_output_shape()

    def _probe_output_shape(self) -> Tuple[int, ...]:
        """Run the processor once on a blank 64x64 RGB array to learn its output shape."""
        try:
            blank = np.zeros((64, 64, 3), dtype=np.uint8)
            probe = self.proc(images=blank, return_tensors="pt")["pixel_values"]
            return tuple(probe.squeeze(0).shape)
        except Exception:
            # Best effort: keep the historical placeholder shape.
            return self._DEFAULT_FALLBACK_SHAPE

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        sample = self.ds[idx]
        caption = extract_first_text(
            sample,
            ["answer", "caption", "captions", "text", "original_alt_text"],
        )
        ids, mask = masked_text_tokenize(caption, self.tok)

        image = sample.get("image", None)
        valid = True
        # Only objects exposing ``convert`` are handed to the processor.
        if image is not None and hasattr(image, "convert"):
            try:
                expert_input = self.proc(
                    images=image.convert("RGB"),
                    return_tensors="pt",
                )["pixel_values"].squeeze(0)
            except Exception:
                expert_input = torch.zeros(self.fallback_shape, dtype=torch.float32)
                valid = False
        else:
            expert_input = torch.zeros(self.fallback_shape, dtype=torch.float32)
            valid = False

        return ids, mask, expert_input, valid


class CodeTextDataset(Dataset):
    """Pairs a tokenized docstring with tokenized source code.

    Each item is ``(input_ids, attention_mask, packed_code, valid)`` where
    ``packed_code`` stacks the code ``input_ids`` and ``attention_mask``
    (256 tokens each) into one ``(2, 256)`` tensor.
    """

    def __init__(self, hf_ds, bert_tokenizer, code_tokenizer):
        self.ds = hf_ds
        self.tok = bert_tokenizer
        self.code_tok = code_tokenizer

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        sample = self.ds[idx]

        doc = sample.get("func_documentation_string", "")
        if not doc or not doc.strip():
            # No docstring: use the first 200 characters of the function text.
            doc = str(sample.get("whole_func_string", ""))[:200]
        doc = str(doc).strip()[:500]
        ids, mask = masked_text_tokenize(doc, self.tok)

        code = str(sample.get("func_code_string", sample.get("whole_func_string", ""))).strip()[:512]
        valid = len(code) > 5 and len(doc) > 5

        # Zero placeholders are kept unless tokenization succeeds.
        code_ids = torch.zeros(256, dtype=torch.long)
        code_mask = torch.zeros(256, dtype=torch.long)
        if valid:
            try:
                tok = self.code_tok(
                    code,
                    padding="max_length",
                    truncation=True,
                    max_length=256,
                    return_tensors="pt",
                )
                code_ids = tok["input_ids"].squeeze(0)
                code_mask = tok["attention_mask"].squeeze(0)
            except Exception:
                valid = False

        return ids, mask, torch.stack([code_ids, code_mask]), valid


# ============================================================================
# SHARED NON-STREAM ENCODER
# ============================================================================

def _pair_features(expert_name: str, expert_hidden_shape: Tuple[int, int]) -> Features:
    """Arrow schema for one (text, expert) pair dataset."""
    return Features({
        "text_hidden": Array2D(shape=(CFG.max_text_len, CFG.bert_hidden_dim), dtype="float16"),
        "text_mask": Sequence(Value("bool"), length=CFG.max_text_len),
        f"{expert_name}_hidden": Array2D(shape=expert_hidden_shape, dtype="float16"),
    })


def _rows_allowed(n_done: int, n_batch: int, max_samples: Optional[int]) -> int:
    """How many rows of the current batch fit under the ``max_samples`` cap."""
    if max_samples is None:
        return n_batch
    return max(0, min(n_batch, max_samples - n_done))


def _write_pairs(writer, expert_name, text_hidden, text_mask_np, expert_hidden, count: int) -> None:
    """Append the first ``count`` rows of an encoded batch to ``writer``."""
    expert_key = f"{expert_name}_hidden"
    for i in range(count):
        writer.add_row({
            "text_hidden": text_hidden[i],
            "text_mask": text_mask_np[i].tolist(),
            expert_key: expert_hidden[i],
        })


def _log_progress(n: int, batch_size: int, max_samples: Optional[int], t0: float) -> None:
    """Print throughput for the first batch and roughly every 1000 samples."""
    if n % 1000 < batch_size or n <= batch_size:
        rate = n / max(time.time() - t0, 1e-6)
        print(f" {n}" + (f"/{max_samples}" if max_samples else "") + f" ({rate:.0f}/s)")


@torch.no_grad()
def encode_map_dataset(
    *,
    tag: str,
    loader: DataLoader,
    bert,
    expert_name: str,
    expert_hidden_shape: Tuple[int, int],
    expert_forward: Callable[[torch.Tensor], torch.Tensor],
    shard_size: int,
    max_samples: Optional[int] = None,
) -> str:
    """Encode a map-style loader into ``<CFG.cache_dir>/<tag>`` and return the path.

    Each batch is ``(text_ids, text_mask, expert_input, valid)``. Rows whose
    ``valid`` flag is False are dropped before encoding. BERT and the expert
    outputs are stored as float16.

    Args:
        tag: Name of the output directory inside ``CFG.cache_dir``.
        loader: DataLoader yielding the 4-tuples described above.
        bert: Text encoder exposing ``last_hidden_state``.
        expert_name: Prefix of the ``<expert_name>_hidden`` column.
        expert_hidden_shape: Per-sample shape of the expert output.
        expert_forward: Maps the batched expert input to hidden states.
        shard_size: Rows per intermediate shard.
        max_samples: Hard cap on rows written; the last batch is truncated
            so the cap is never exceeded.

    Returns:
        The cache path. If it already exists nothing is recomputed.

    Raises:
        ValueError: if no valid row was produced.
    """
    cache_path = os.path.join(CFG.cache_dir, tag)
    if os.path.exists(cache_path):
        print(f" Cache exists: {cache_path}")
        return cache_path

    writer = ShardWriter(
        cache_dir=CFG.cache_dir,
        tag=tag,
        features=_pair_features(expert_name, expert_hidden_shape),
        shard_size=shard_size,
        row_keys=["text_hidden", "text_mask", f"{expert_name}_hidden"],
    )

    t0 = time.time()
    n = 0
    for batch in loader:
        text_ids, text_mask, expert_input, valid = batch
        valid_b = valid.bool()
        if not valid_b.any():
            continue

        text_ids = text_ids[valid_b].to(DEVICE, non_blocking=True)
        text_mask_gpu = text_mask[valid_b].to(DEVICE, non_blocking=True)
        expert_input = expert_input[valid_b].to(DEVICE, non_blocking=True)

        text_hidden = bert(
            input_ids=text_ids,
            attention_mask=text_mask_gpu,
        ).last_hidden_state.detach().to(dtype=torch.float16).cpu().numpy()
        text_mask_np = text_mask_gpu.bool().cpu().numpy()
        expert_hidden = expert_forward(expert_input).detach().to(dtype=torch.float16).cpu().numpy()

        # Truncate the final batch so exactly max_samples rows are written.
        take = _rows_allowed(n, text_hidden.shape[0], max_samples)
        _write_pairs(writer, expert_name, text_hidden, text_mask_np, expert_hidden, take)
        n += take

        _log_progress(n, CFG.batch_size, max_samples, t0)
        if max_samples is not None and n >= max_samples:
            break

    final_path = writer.finalize()
    print(f" Completed {n} samples in {time.time() - t0:.0f}s")
    return final_path


# ============================================================================
# STREAMING HELPERS
# ============================================================================

def decode_audio_obj(audio_obj) -> Tuple[np.ndarray, int]:
    """Return ``(waveform, sampling_rate)`` from a decoded audio object.

    Supports objects exposing ``get_all_samples()`` and dicts carrying
    ``"array"`` (sampling rate defaults to 16000 when absent).

    Raises:
        TypeError: for any other input type.
    """
    if hasattr(audio_obj, "get_all_samples"):
        samples = audio_obj.get_all_samples()
        arr = samples.data.numpy().squeeze()
        sr = samples.sample_rate
        return arr, sr
    if isinstance(audio_obj, dict):
        return audio_obj["array"], audio_obj.get("sampling_rate", 16000)
    raise TypeError(f"Unsupported audio object type: {type(audio_obj)}")


def stream_librispeech_batches(
    stream,
    bert_tokenizer,
    whisper_processor,
    batch_size: int,
) -> Iterable[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Yield ``(text_ids, text_masks, mel_features)`` batches from ``stream``.

    Samples without text or audio, or whose audio cannot be decoded or
    featurized, are skipped. A final partial batch is yielded if non-empty.
    """
    batch_ids = []
    batch_masks = []
    batch_mels = []

    for sample in stream:
        text = sample.get("text", sample.get("transcription", ""))
        audio_obj = sample.get("audio")
        if not text or audio_obj is None:
            continue

        try:
            audio_array, sr = decode_audio_obj(audio_obj)
        except Exception:
            continue

        ids, mask = masked_text_tokenize(str(text), bert_tokenizer)

        try:
            mel = whisper_processor(
                audio_array,
                sampling_rate=sr,
                return_tensors="pt",
            ).input_features.squeeze(0)
        except Exception:
            continue

        batch_ids.append(ids)
        batch_masks.append(mask)
        batch_mels.append(mel)

        if len(batch_ids) >= batch_size:
            yield (
                torch.stack(batch_ids),
                torch.stack(batch_masks),
                torch.stack(batch_mels),
            )
            batch_ids, batch_masks, batch_mels = [], [], []

    if batch_ids:
        yield (
            torch.stack(batch_ids),
            torch.stack(batch_masks),
            torch.stack(batch_mels),
        )


def extract_protein_caption(sample: Dict[str, Any]) -> str:
    """Pick a caption of at most 500 characters for a protein sample.

    Preference order: the first non-empty ``"value"`` of a conversation turn
    with ``"from" == "gpt"``; then the first non-empty ``"value"`` of any
    turn; finally the sample's ``"protein"`` field.
    """
    convos = sample.get("conversations", [])
    if isinstance(convos, list):
        for c in convos:
            if isinstance(c, dict) and c.get("from") == "gpt":
                v = str(c.get("value", "")).strip()
                if v:
                    return v[:500]
        for c in convos:
            if isinstance(c, dict) and "value" in c:
                v = str(c["value"]).strip()
                if v:
                    return v[:500]
    return str(sample.get("protein", "")).strip()[:500]


def stream_protein_batches(
    stream,
    bert_tokenizer,
    esm_tokenizer,
    batch_size: int,
) -> Iterable[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Yield ``(text_ids, text_masks, esm_ids, esm_masks)`` batches.

    Samples whose caption or sequence is shorter than 5 characters, or whose
    sequence cannot be tokenized, are skipped. Sequences are padded/truncated
    to 512 tokens. A final partial batch is yielded if non-empty.
    """
    batch_ids = []
    batch_masks = []
    batch_esm_ids = []
    batch_esm_masks = []

    for sample in stream:
        caption = extract_protein_caption(sample)
        seq = str(sample.get("amino_seq", sample.get("protein_sequence", ""))).strip()
        if len(caption) < 5 or len(seq) < 5:
            continue

        ids, mask = masked_text_tokenize(caption, bert_tokenizer)

        try:
            esm_t = esm_tokenizer(
                seq,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
        except Exception:
            continue

        batch_ids.append(ids)
        batch_masks.append(mask)
        batch_esm_ids.append(esm_t["input_ids"].squeeze(0))
        batch_esm_masks.append(esm_t["attention_mask"].squeeze(0))

        if len(batch_ids) >= batch_size:
            yield (
                torch.stack(batch_ids),
                torch.stack(batch_masks),
                torch.stack(batch_esm_ids),
                torch.stack(batch_esm_masks),
            )
            batch_ids, batch_masks, batch_esm_ids, batch_esm_masks = [], [], [], []

    if batch_ids:
        yield (
            torch.stack(batch_ids),
            torch.stack(batch_masks),
            torch.stack(batch_esm_ids),
            torch.stack(batch_esm_masks),
        )


@torch.no_grad()
def encode_streaming_batches(
    *,
    tag: str,
    expert_name: str,
    expert_hidden_shape: Tuple[int, int],
    batch_iter: Iterable,
    bert,
    expert_batch_forward: Callable[..., torch.Tensor],
    shard_size: int,
    row_keys: List[str],
    max_samples: Optional[int] = None,
) -> str:
    """Encode pre-batched streaming data into ``<CFG.cache_dir>/<tag>``.

    Every item of ``batch_iter`` is a tuple whose first two tensors are the
    BERT ids and attention mask; the remaining tensors are moved to
    ``DEVICE`` and passed positionally to ``expert_batch_forward``.

    Args:
        tag: Name of the output directory inside ``CFG.cache_dir``.
        expert_name: Prefix of the ``<expert_name>_hidden`` column.
        expert_hidden_shape: Per-sample shape of the expert output.
        batch_iter: Iterable of packed batches as described above.
        bert: Text encoder exposing ``last_hidden_state``.
        expert_batch_forward: Maps the expert tensors to hidden states.
        shard_size: Rows per intermediate shard.
        row_keys: Column names handed to the shard writer.
        max_samples: Hard cap on rows written; the last batch is truncated
            so the cap is never exceeded.

    Returns:
        The cache path. If it already exists nothing is recomputed.

    Raises:
        ValueError: if the stream produced no rows.
    """
    cache_path = os.path.join(CFG.cache_dir, tag)
    if os.path.exists(cache_path):
        print(f" Cache exists: {cache_path}")
        return cache_path

    writer = ShardWriter(
        cache_dir=CFG.cache_dir,
        tag=tag,
        features=_pair_features(expert_name, expert_hidden_shape),
        shard_size=shard_size,
        row_keys=row_keys,
    )

    t0 = time.time()
    n = 0
    for packed in batch_iter:
        # first two are always bert ids/masks
        text_ids = packed[0].to(DEVICE, non_blocking=True)
        text_mask = packed[1].to(DEVICE, non_blocking=True)

        text_hidden = bert(
            input_ids=text_ids,
            attention_mask=text_mask,
        ).last_hidden_state.detach().to(dtype=torch.float16).cpu().numpy()
        text_mask_np = text_mask.bool().cpu().numpy()

        expert_hidden = expert_batch_forward(*[p.to(DEVICE, non_blocking=True) for p in packed[2:]])
        expert_hidden = expert_hidden.detach().to(dtype=torch.float16).cpu().numpy()

        batch_size = text_hidden.shape[0]
        # Truncate the final batch so exactly max_samples rows are written.
        take = _rows_allowed(n, batch_size, max_samples)
        _write_pairs(writer, expert_name, text_hidden, text_mask_np, expert_hidden, take)
        n += take

        _log_progress(n, batch_size, max_samples, t0)
        if max_samples is not None and n >= max_samples:
            break

    final_path = writer.finalize()
    print(f" Completed {n} samples in {time.time() - t0:.0f}s")
    return final_path
# ============================================================================
# EXPERT RUNNERS
# ============================================================================

def encode_image_expert(bert, split: str, tag: str, max_samples: Optional[int] = None) -> str:
    """Encode COCO-Caption image/text pairs with DINOv2-large and BERT.

    Args:
        bert: Shared text encoder.
        split: Split of ``lmms-lab/COCO-Caption`` to load.
        tag: Output directory name inside the cache.
        max_samples: When given, only the first ``max_samples`` rows are used.

    Returns:
        Path of the saved pair dataset.
    """
    from transformers import Dinov2Model, AutoImageProcessor

    print(f"\n [IMAGE] Loading DINOv2-large + COCO-Caption ({split})...")
    weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    dino = Dinov2Model.from_pretrained("facebook/dinov2-large", torch_dtype=weight_dtype)
    dino = dino.to(DEVICE).eval()
    proc = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
    tok = get_bert_tokenizer()

    hf_ds = load_dataset("lmms-lab/COCO-Caption", split=split)
    if max_samples is not None:
        n_keep = min(max_samples, len(hf_ds))
        hf_ds = hf_ds.select(range(n_keep))
    print(f" Dataset: {len(hf_ds)} samples")

    torch_ds = ImageTextDataset(hf_ds, tok, proc)
    loader = make_loader(torch_ds, batch_size=CFG.batch_size, num_workers=CFG.num_workers)

    def expert_forward(pixel_values):
        outputs = dino(pixel_values=pixel_values)
        return outputs.last_hidden_state

    path = encode_map_dataset(
        tag=tag,
        loader=loader,
        bert=bert,
        expert_name="image",
        expert_hidden_shape=(257, 1024),
        expert_forward=expert_forward,
        shard_size=CFG.shard_size_default,
        max_samples=max_samples,
    )

    # Drop the expert and its data pipeline before the next expert loads.
    del dino, proc, hf_ds, torch_ds, loader
    cleanup_cuda()
    return path


def encode_code_expert(bert, max_samples: Optional[int] = None) -> str:
    """Encode CodeSearchNet (python, train) docstring/code pairs with CodeBERT.

    Rows are first limited to ``max_samples`` (when given) and then filtered
    to those with a non-blank ``func_documentation_string``.

    Returns:
        Path of the saved ``code_csn`` pair dataset.
    """
    from transformers import RobertaModel, RobertaTokenizer

    print("\n [CODE] Loading CodeBERT + CodeSearchNet python...")
    weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    codebert = RobertaModel.from_pretrained("microsoft/codebert-base", torch_dtype=weight_dtype)
    codebert = codebert.to(DEVICE).eval()
    code_tok = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
    tok = get_bert_tokenizer()

    hf_ds = load_dataset("code-search-net/code_search_net", "python", split="train")
    if max_samples is not None:
        n_keep = min(max_samples, len(hf_ds))
        hf_ds = hf_ds.select(range(n_keep))
    hf_ds = hf_ds.filter(
        lambda row: bool(row.get("func_documentation_string", "").strip()),
        num_proc=4,
    )
    print(f" Dataset: {len(hf_ds)} samples (after filtering)")

    torch_ds = CodeTextDataset(hf_ds, tok, code_tok)
    loader = make_loader(torch_ds, batch_size=CFG.batch_size, num_workers=CFG.num_workers)

    def expert_forward(packed):
        # Index 0 / 1 along dim 1 hold the code token ids / attention mask.
        outputs = codebert(
            input_ids=packed[:, 0].long(),
            attention_mask=packed[:, 1].long(),
        )
        return outputs.last_hidden_state

    path = encode_map_dataset(
        tag="code_csn",
        loader=loader,
        bert=bert,
        expert_name="code",
        expert_hidden_shape=(256, 768),
        expert_forward=expert_forward,
        shard_size=CFG.shard_size_default,
        max_samples=max_samples,
    )

    del codebert, code_tok, hf_ds, torch_ds, loader
    cleanup_cuda()
    return path


def encode_audio_expert(bert, max_samples: Optional[int] = None) -> str:
    """Encode LibriSpeech (clean, train.100) audio/transcript pairs with Whisper.

    The encoder output shape is measured on the first streamed sample before
    the real pass starts. A falsy ``max_samples`` falls back to
    ``CFG.max_audio_samples``.

    Returns:
        Path of the saved ``audio_librispeech`` pair dataset.
    """
    from transformers import WhisperModel, WhisperFeatureExtractor

    print("\n [AUDIO] Loading Whisper-large-v3 + LibriSpeech ASR (streaming)...")
    whisper = WhisperModel.from_pretrained(
        "openai/whisper-large-v3",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    whisper = whisper.to(DEVICE).eval()
    proc = WhisperFeatureExtractor.from_pretrained("openai/whisper-large-v3")
    tok = get_bert_tokenizer()
    max_n = max_samples or CFG.max_audio_samples
    audio_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # probe shape
    probe_stream = load_dataset("openslr/librispeech_asr", "clean", split="train.100", streaming=True)
    probe_stream = probe_stream.cast_column("audio", Audio(sampling_rate=16000))
    first = next(iter(probe_stream))
    arr, sr = decode_audio_obj(first["audio"])
    mel = proc(arr, sampling_rate=sr, return_tensors="pt").input_features
    mel = mel.to(device=DEVICE, dtype=audio_dtype)
    with torch.no_grad():
        probe_hidden = whisper.encoder(mel).last_hidden_state
    seq_len = probe_hidden.shape[1]
    hidden_dim = probe_hidden.shape[2]
    print(f" Whisper encoder output: ({seq_len}, {hidden_dim})")
    del mel, probe_hidden

    # A fresh stream is opened for the real pass.
    stream = load_dataset("openslr/librispeech_asr", "clean", split="train.100", streaming=True)
    stream = stream.cast_column("audio", Audio(sampling_rate=16000))
    batch_iter = stream_librispeech_batches(
        stream=stream,
        bert_tokenizer=tok,
        whisper_processor=proc,
        batch_size=16,
    )

    def expert_batch_forward(mels: torch.Tensor) -> torch.Tensor:
        encoded = whisper.encoder(mels.to(dtype=audio_dtype))
        return encoded.last_hidden_state

    path = encode_streaming_batches(
        tag="audio_librispeech",
        expert_name="audio",
        expert_hidden_shape=(seq_len, hidden_dim),
        batch_iter=batch_iter,
        bert=bert,
        expert_batch_forward=expert_batch_forward,
        shard_size=256,  # large hidden size; keep shards smaller
        row_keys=["text_hidden", "text_mask", "audio_hidden"],
        max_samples=max_n,
    )

    del whisper, proc
    cleanup_cuda()
    return path


def encode_protein_expert(bert, max_samples: Optional[int] = None) -> str:
    """Encode Protein2Text-QA (test split, streamed) pairs with ESM-2-650M.

    A falsy ``max_samples`` falls back to ``CFG.max_protein_samples``.

    Returns:
        Path of the saved ``protein_p2t`` pair dataset.
    """
    from transformers import EsmModel, EsmTokenizer

    print("\n [PROTEIN] Loading ESM-2-650M + Protein2Text-QA (streaming)...")
    esm = EsmModel.from_pretrained(
        "facebook/esm2_t33_650M_UR50D",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    esm = esm.to(DEVICE).eval()
    esm_tok = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    tok = get_bert_tokenizer()
    max_n = max_samples or CFG.max_protein_samples

    stream = load_dataset("tumorailab/Protein2Text-QA", split="test", streaming=True)
    batch_iter = stream_protein_batches(
        stream=stream,
        bert_tokenizer=tok,
        esm_tokenizer=esm_tok,
        batch_size=32,
    )

    def expert_batch_forward(esm_ids: torch.Tensor, esm_mask: torch.Tensor) -> torch.Tensor:
        encoded = esm(input_ids=esm_ids.long(), attention_mask=esm_mask.long())
        return encoded.last_hidden_state

    path = encode_streaming_batches(
        tag="protein_p2t",
        expert_name="protein",
        expert_hidden_shape=(512, 1280),
        batch_iter=batch_iter,
        bert=bert,
        expert_batch_forward=expert_batch_forward,
        shard_size=512,
        row_keys=["text_hidden", "text_mask", "protein_hidden"],
        max_samples=max_n,
    )

    del esm, esm_tok
    cleanup_cuda()
    return path
# ============================================================================
# MAIN
# ============================================================================

def main():
    """Run every expert encoder whose cache is missing, then print a summary.

    Image encoding errors propagate; audio, protein and code failures are
    reported and recorded as ``None`` so the remaining experts still run.
    The HF cache is purged (when enabled) only after an encode attempt.
    """
    ensure_dir(CFG.cache_dir)

    banner = "=" * 70
    rule = "─" * 50
    purge_hf = CFG.cleanup_hf_cache_between_experts

    def cache_for(tag: str) -> str:
        return os.path.join(CFG.cache_dir, tag)

    print(banner)
    print("STAGE 1: MULTI-EXPERT PRECOMPUTE")
    print(banner)
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"Cache: {CFG.cache_dir}")

    required_tags = [
        "image_coco",
        "image_coco_test",
        "audio_librispeech",
        "protein_p2t",
        "code_csn",
    ]
    missing = [tag for tag in required_tags if not os.path.exists(cache_for(tag))]

    if missing:
        print(f"\nMissing caches: {missing}")
        if purge_hf:
            cleanup_hf_cache()
        bert = load_shared_bert()
    else:
        # BERT is only loaded when at least one cache has to be built.
        print("\nAll caches exist. Nothing to encode.")
        bert = None

    paths: Dict[str, Optional[str]] = {}

    # IMAGE TRAIN
    print(f"\n{rule}")
    print("[1/4] IMAGE — DINOv2 + COCO-Caption")
    image_dir = cache_for("image_coco")
    if os.path.exists(image_dir):
        print(" [IMAGE] Cache exists, skipping.")
        paths["image"] = image_dir
    else:
        paths["image"] = encode_image_expert(bert, split="val", tag="image_coco")
        if purge_hf:
            cleanup_hf_cache()

    # IMAGE TEST
    image_test_dir = cache_for("image_coco_test")
    if os.path.exists(image_test_dir):
        print("\n [IMAGE-TEST] Cache exists, skipping.")
        paths["image_test"] = image_test_dir
    else:
        print("\n [IMAGE-TEST] COCO test split...")
        paths["image_test"] = encode_image_expert(bert, split="test", tag="image_coco_test")
        if purge_hf:
            cleanup_hf_cache()

    # AUDIO
    print(f"\n{rule}")
    print("[2/4] AUDIO — Whisper + LibriSpeech")
    audio_dir = cache_for("audio_librispeech")
    if os.path.exists(audio_dir):
        print(" [AUDIO] Cache exists, skipping.")
        paths["audio"] = audio_dir
    else:
        try:
            paths["audio"] = encode_audio_expert(bert, max_samples=CFG.max_audio_samples)
        except Exception as e:
            print(f" AUDIO failed: {e}")
            paths["audio"] = None
        if purge_hf:
            cleanup_hf_cache()

    # PROTEIN
    print(f"\n{rule}")
    print("[3/4] PROTEIN — ESM-2 + Protein2Text-QA")
    protein_dir = cache_for("protein_p2t")
    if os.path.exists(protein_dir):
        print(" [PROTEIN] Cache exists, skipping.")
        paths["protein"] = protein_dir
    else:
        try:
            paths["protein"] = encode_protein_expert(bert, max_samples=CFG.max_protein_samples)
        except Exception as e:
            print(f" PROTEIN failed: {e}")
            paths["protein"] = None
        if purge_hf:
            cleanup_hf_cache()

    # CODE
    print(f"\n{rule}")
    print("[4/4] CODE — CodeBERT + CodeSearchNet Python")
    code_dir = cache_for("code_csn")
    if os.path.exists(code_dir):
        print(" [CODE] Cache exists, skipping.")
        paths["code"] = code_dir
    else:
        try:
            paths["code"] = encode_code_expert(bert, max_samples=CFG.max_code_samples)
        except Exception as e:
            print(f" CODE failed: {e}")
            paths["code"] = None
        if purge_hf:
            cleanup_hf_cache()

    if bert is not None:
        del bert
        cleanup_cuda()

    # A flickr30k cache is listed in the summary when one is already on disk.
    flickr_path = cache_for("flickr30k")
    if os.path.exists(flickr_path):
        paths["flickr"] = flickr_path

    print(f"\n{banner}")
    print("CACHE SUMMARY")
    print(banner)
    for name, path in paths.items():
        if not path or not os.path.exists(path):
            continue
        ds = load_from_disk(path)
        print(f" {name:15s}: {len(ds):6d} pairs [{path}]")

    print("\nReady for Stage 2 multi-expert training.")
    print("Done.")


if __name__ == "__main__":
    main()