| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
import subprocess
import sys

# Bootstrap: repair a missing or partially-broken sympy install before
# anything else imports it. NOTE(review): this runs pip at import time —
# network access and side effects on the interpreter environment; confirm
# this is intended for the deployment target.
try:
    import sympy
    _ = sympy.core  # probe an attribute to catch partially-broken installs
except (ImportError, AttributeError):
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--upgrade", "sympy", "--break-system-packages", "-q"]
    )
| |
|
| | import gc |
| | import os |
| | import shutil |
| | import time |
| | from dataclasses import dataclass |
| | from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| | from torch.utils.data import Dataset, DataLoader |
| | from datasets import ( |
| | Audio, |
| | Dataset as HFDataset, |
| | Features, |
| | Sequence, |
| | Value, |
| | Array2D, |
| | concatenate_datasets, |
| | load_dataset, |
| | load_from_disk, |
| | ) |
| |
|
| | |
| | |
| | |
| |
|
@dataclass
class BaseConfig:
    """Global knobs shared by every expert-encoding pass in this script."""

    cache_dir: str = "/home/claude/geo_cache"  # root directory for all encoded caches
    max_text_len: int = 32  # fixed BERT token length for every caption

    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    amp_enabled: bool = torch.cuda.is_available()

    # Shared text encoder used for every modality.
    bert_model_name: str = "google-bert/bert-large-uncased"
    bert_hidden_dim: int = 1024

    # DataLoader settings for the map-style (image/code) datasets.
    batch_size: int = 256
    num_workers: int = 8
    prefetch_factor: int = 2
    pin_memory: bool = torch.cuda.is_available()

    # Rows buffered per shard before ShardWriter flushes to disk.
    shard_size_default: int = 2048

    # Per-modality sample caps for the streaming / very large datasets.
    max_audio_samples: int = 10000
    max_protein_samples: int = 15000
    max_code_samples: int = 50000

    # Wipe ~/.cache/huggingface between experts to bound peak disk usage.
    cleanup_hf_cache_between_experts: bool = True
| |
|
| |
|
# Module-level singletons shared by every encoder below.
CFG = BaseConfig()
DEVICE = torch.device(CFG.device)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def cleanup_hf_cache() -> None:
    """Delete HF datasets/hub cache to free disk after encoding an expert."""
    root = os.path.expanduser("~/.cache/huggingface")
    for name in ("datasets", "hub"):
        target = os.path.join(root, name)
        if not os.path.exists(target):
            continue

        # Tally the on-disk size purely for the log line below.
        total_bytes = 0
        for dirpath, _, filenames in os.walk(target):
            for filename in filenames:
                try:
                    total_bytes += os.path.getsize(os.path.join(dirpath, filename))
                except OSError:
                    pass  # file vanished mid-walk; size is best-effort anyway
        size_gb = total_bytes / 1e9

        print(f" Cleaning {target} ({size_gb:.1f} GB)...")
        shutil.rmtree(target, ignore_errors=True)
        os.makedirs(target, exist_ok=True)
| |
|
| |
|
def cleanup_cuda() -> None:
    """Collect Python garbage and, when a GPU is present, release cached CUDA blocks."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
| |
|
| |
|
| | |
| | |
| | |
| |
|
_bert_tokenizer = None  # lazily created, process-wide BertTokenizer singleton


def get_bert_tokenizer():
    """Return the shared BertTokenizer, creating it on first use."""
    global _bert_tokenizer
    if _bert_tokenizer is not None:
        return _bert_tokenizer
    from transformers import BertTokenizer
    _bert_tokenizer = BertTokenizer.from_pretrained(CFG.bert_model_name)
    return _bert_tokenizer
| |
|
| |
|
def load_shared_bert():
    """Load the shared BERT-large encoder (fp16 on GPU) in eval mode."""
    from transformers import BertModel
    print("Loading shared BERT-large...")
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = BertModel.from_pretrained(CFG.bert_model_name, torch_dtype=dtype)
    model = model.to(DEVICE)
    model.eval()
    print(" BERT ready.")
    return model
| |
|
| |
|
| | |
| | |
| | |
| |
|
def ensure_dir(path: str) -> None:
    """Create *path* (including parents) if it does not already exist."""
    if not os.path.isdir(path):
        os.makedirs(path, exist_ok=True)
| |
|
| |
|
| | def make_loader(ds: Dataset, batch_size: int, num_workers: int) -> DataLoader: |
| | kwargs = dict( |
| | dataset=ds, |
| | batch_size=batch_size, |
| | shuffle=False, |
| | num_workers=num_workers, |
| | pin_memory=CFG.pin_memory, |
| | persistent_workers=num_workers > 0, |
| | ) |
| | if num_workers > 0: |
| | kwargs["prefetch_factor"] = CFG.prefetch_factor |
| | return DataLoader(**kwargs) |
| |
|
| |
|
| | def masked_text_tokenize(text: str, tokenizer) -> Tuple[torch.Tensor, torch.Tensor]: |
| | tok = tokenizer( |
| | text, |
| | padding="max_length", |
| | truncation=True, |
| | max_length=CFG.max_text_len, |
| | return_tensors="pt", |
| | ) |
| | return tok["input_ids"].squeeze(0), tok["attention_mask"].squeeze(0) |
| |
|
| |
|
| | def extract_first_text(sample: Dict[str, Any], keys: List[str]) -> str: |
| | for key in keys: |
| | if key not in sample: |
| | continue |
| | value = sample[key] |
| |
|
| | if isinstance(value, str): |
| | value = value.strip() |
| | if value: |
| | return value |
| |
|
| | if isinstance(value, list) and value: |
| | first = value[0] |
| | if isinstance(first, str): |
| | first = first.strip() |
| | if first: |
| | return first |
| | if isinstance(first, dict): |
| | txt = str(first.get("raw", first.get("text", ""))).strip() |
| | if txt: |
| | return txt |
| | txt = str(first).strip() |
| | if txt: |
| | return txt |
| |
|
| | return "" |
| |
|
| |
|
| | |
| | |
| | |
| |
|
class ShardWriter:
    """Buffers rows and periodically saves them as HF dataset shards, then
    merges all shards into one dataset at ``cache_dir/tag``.

    Shards live under ``cache_dir/tag__shards`` and are deleted after the
    final merge.
    """

    def __init__(
        self,
        cache_dir: str,
        tag: str,
        features: Features,
        shard_size: int,
        row_keys: List[str],
    ):
        self.cache_dir = cache_dir
        self.tag = tag
        self.features = features
        self.shard_size = shard_size  # rows buffered before a shard is flushed
        self.row_keys = row_keys

        self.cache_path = os.path.join(cache_dir, tag)
        self.shard_root = os.path.join(cache_dir, f"{tag}__shards")

        # Column-major buffer: one list per row key.
        self.rows: Dict[str, list] = {k: [] for k in row_keys}
        self.shard_paths: List[str] = []
        self.shard_idx = 0
        self.n_written = 0

    @property
    def exists(self) -> bool:
        """True when the final merged dataset is already on disk."""
        return os.path.exists(self.cache_path)

    def add_row(self, row: Dict[str, Any]) -> None:
        """Buffer one row; flush a shard once shard_size rows accumulate."""
        for k in self.row_keys:
            self.rows[k].append(row[k])

        if len(self.rows[self.row_keys[0]]) >= self.shard_size:
            self.flush()

    def flush(self) -> None:
        """Write buffered rows (if any) to a new shard directory and reset the buffer."""
        n_rows = len(self.rows[self.row_keys[0]])
        if n_rows == 0:
            return

        ensure_dir(self.shard_root)
        shard_path = os.path.join(self.shard_root, f"shard_{self.shard_idx:05d}")
        ds = HFDataset.from_dict(self.rows, features=self.features)
        ds.save_to_disk(shard_path)

        self.shard_paths.append(shard_path)
        self.shard_idx += 1
        self.n_written += n_rows
        self.rows = {k: [] for k in self.row_keys}

    def finalize(self) -> str:
        """Flush remaining rows, merge all shards into cache_path, remove the
        temporary shard directory, and return cache_path.

        Raises:
            RuntimeError: if no rows were ever written (previously this
                crashed inside concatenate_datasets([]) with an opaque error).
        """
        self.flush()

        if not self.shard_paths:
            raise RuntimeError(f"ShardWriter({self.tag!r}): no rows were written")

        print(f" Merging {len(self.shard_paths)} shards...")
        merged = concatenate_datasets([load_from_disk(p) for p in self.shard_paths])
        merged.save_to_disk(self.cache_path)
        print(f" Saved {len(merged)} pairs to {self.cache_path}")

        if os.path.exists(self.shard_root):
            shutil.rmtree(self.shard_root, ignore_errors=True)

        return self.cache_path
| |
|
| |
|
| | |
| | |
| | |
| |
|
class ImageTextDataset(Dataset):
    """Pairs COCO-style captions (BERT-tokenized) with image pixel tensors.

    Samples without a decodable image yield a zero tensor and valid=False so
    the encoding loop can drop them batch-wise.
    """

    def __init__(self, hf_ds, bert_tokenizer, image_processor):
        self.ds = hf_ds
        self.tok = bert_tokenizer
        self.proc = image_processor
        # Shape emitted for invalid samples, matching the expert's input size.
        self.fallback_shape = (3, 518, 518)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        sample = self.ds[idx]

        caption = extract_first_text(
            sample,
            ["answer", "caption", "captions", "text", "original_alt_text"],
        )
        ids, mask = masked_text_tokenize(caption, self.tok)

        image = sample.get("image", None)
        zeros = lambda: torch.zeros(self.fallback_shape, dtype=torch.float32)

        # Only PIL-like objects (anything exposing .convert) are usable.
        if image is None or not hasattr(image, "convert"):
            return ids, mask, zeros(), False

        try:
            pixels = self.proc(
                images=image.convert("RGB"),
                return_tensors="pt",
            )["pixel_values"].squeeze(0)
        except Exception:
            return ids, mask, zeros(), False

        return ids, mask, pixels, True
| |
|
| |
|
class CodeTextDataset(Dataset):
    """Pairs CodeSearchNet docstrings (BERT-tokenized) with code token ids.

    Each item is (text_ids, text_mask, stacked [code_ids, code_mask], valid).
    """

    def __init__(self, hf_ds, bert_tokenizer, code_tokenizer):
        self.ds = hf_ds
        self.tok = bert_tokenizer
        self.code_tok = code_tokenizer

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        sample = self.ds[idx]

        doc = sample.get("func_documentation_string", "")
        if not doc or not doc.strip():
            # Fall back to the function source itself as a pseudo-docstring.
            doc = str(sample.get("whole_func_string", ""))[:200]
        doc = str(doc).strip()[:500]

        ids, mask = masked_text_tokenize(doc, self.tok)

        code = str(sample.get("func_code_string", sample.get("whole_func_string", ""))).strip()[:512]
        valid = len(code) > 5 and len(doc) > 5

        # Default to zero tensors; replaced below when tokenization succeeds.
        code_ids = torch.zeros(256, dtype=torch.long)
        code_mask = torch.zeros(256, dtype=torch.long)
        if valid:
            try:
                encoded = self.code_tok(
                    code,
                    padding="max_length",
                    truncation=True,
                    max_length=256,
                    return_tensors="pt",
                )
                code_ids = encoded["input_ids"].squeeze(0)
                code_mask = encoded["attention_mask"].squeeze(0)
            except Exception:
                code_ids = torch.zeros(256, dtype=torch.long)
                code_mask = torch.zeros(256, dtype=torch.long)
                valid = False

        return ids, mask, torch.stack([code_ids, code_mask]), valid
| |
|
| |
|
| | |
| | |
| | |
| |
|
@torch.no_grad()
def encode_map_dataset(
    *,
    tag: str,
    loader: DataLoader,
    bert,
    expert_name: str,
    expert_hidden_shape: Tuple[int, int],
    expert_forward: Callable[[torch.Tensor], torch.Tensor],
    shard_size: int,
    max_samples: Optional[int] = None,
) -> str:
    """Encode a map-style DataLoader into a cached HF dataset.

    Each batch yields (text_ids, text_mask, expert_input, valid); invalid
    samples are dropped, text goes through the shared *bert*, expert input
    through *expert_forward*, and rows (fp16 hidden states + bool mask) are
    written via ShardWriter. Returns the cache path; an existing cache
    short-circuits all work.
    """
    cache_path = os.path.join(CFG.cache_dir, tag)
    if os.path.exists(cache_path):
        print(f" Cache exists: {cache_path}")
        return cache_path

    features = Features({
        "text_hidden": Array2D(shape=(CFG.max_text_len, CFG.bert_hidden_dim), dtype="float16"),
        "text_mask": Sequence(Value("bool"), length=CFG.max_text_len),
        f"{expert_name}_hidden": Array2D(shape=expert_hidden_shape, dtype="float16"),
    })

    writer = ShardWriter(
        cache_dir=CFG.cache_dir,
        tag=tag,
        features=features,
        shard_size=shard_size,
        row_keys=["text_hidden", "text_mask", f"{expert_name}_hidden"],
    )

    t0 = time.time()
    n = 0  # total valid samples written so far

    for batch in loader:
        text_ids, text_mask, expert_input, valid = batch
        valid_b = valid.bool()

        if not valid_b.any():
            continue

        # Keep only valid samples before paying GPU transfer/compute cost.
        text_ids = text_ids[valid_b].to(DEVICE, non_blocking=True)
        text_mask_gpu = text_mask[valid_b].to(DEVICE, non_blocking=True)
        expert_input = expert_input[valid_b].to(DEVICE, non_blocking=True)

        text_hidden = bert(
            input_ids=text_ids,
            attention_mask=text_mask_gpu,
        ).last_hidden_state.detach().to(dtype=torch.float16).cpu().numpy()

        text_mask_np = text_mask_gpu.bool().cpu().numpy()
        expert_hidden = expert_forward(expert_input).detach().to(dtype=torch.float16).cpu().numpy()

        for i in range(text_hidden.shape[0]):
            writer.add_row({
                "text_hidden": text_hidden[i],
                "text_mask": text_mask_np[i].tolist(),
                f"{expert_name}_hidden": expert_hidden[i],
            })

        n += text_hidden.shape[0]
        # Log roughly once per 1000 samples (and always on the first batch).
        if n % 1000 < CFG.batch_size or n <= CFG.batch_size:
            rate = n / max(time.time() - t0, 1e-6)
            print(f" {n}" + (f"/{max_samples}" if max_samples else "") + f" ({rate:.0f}/s)")

        if max_samples is not None and n >= max_samples:
            break

    final_path = writer.finalize()
    print(f" Completed {n} samples in {time.time() - t0:.0f}s")
    return final_path
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | def decode_audio_obj(audio_obj) -> Tuple[np.ndarray, int]: |
| | if hasattr(audio_obj, "get_all_samples"): |
| | samples = audio_obj.get_all_samples() |
| | arr = samples.data.numpy().squeeze() |
| | sr = samples.sample_rate |
| | return arr, sr |
| |
|
| | if isinstance(audio_obj, dict): |
| | return audio_obj["array"], audio_obj.get("sampling_rate", 16000) |
| |
|
| | raise TypeError(f"Unsupported audio object type: {type(audio_obj)}") |
| |
|
| |
|
def stream_librispeech_batches(
    stream,
    bert_tokenizer,
    whisper_processor,
    batch_size: int,
) -> Iterable[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Yield (text_ids, text_mask, mel) batches from a streaming ASR dataset.

    Samples missing text/audio, or failing audio decode or mel extraction,
    are skipped silently. A final partial batch is yielded at end of stream.
    """
    ids_buf: List[torch.Tensor] = []
    masks_buf: List[torch.Tensor] = []
    mels_buf: List[torch.Tensor] = []

    def emit():
        return torch.stack(ids_buf), torch.stack(masks_buf), torch.stack(mels_buf)

    for sample in stream:
        transcript = sample.get("text", sample.get("transcription", ""))
        audio_obj = sample.get("audio")
        if not transcript or audio_obj is None:
            continue

        try:
            audio_array, sr = decode_audio_obj(audio_obj)
        except Exception:
            continue  # undecodable audio: drop the sample

        ids, mask = masked_text_tokenize(str(transcript), bert_tokenizer)

        try:
            mel = whisper_processor(
                audio_array,
                sampling_rate=sr,
                return_tensors="pt",
            ).input_features.squeeze(0)
        except Exception:
            continue  # feature extraction failed: drop the sample

        ids_buf.append(ids)
        masks_buf.append(mask)
        mels_buf.append(mel)

        if len(ids_buf) >= batch_size:
            yield emit()
            ids_buf, masks_buf, mels_buf = [], [], []

    if ids_buf:
        yield emit()
| |
|
| |
|
| | def extract_protein_caption(sample: Dict[str, Any]) -> str: |
| | convos = sample.get("conversations", []) |
| | if isinstance(convos, list): |
| | for c in convos: |
| | if isinstance(c, dict) and c.get("from") == "gpt": |
| | v = str(c.get("value", "")).strip() |
| | if v: |
| | return v[:500] |
| | for c in convos: |
| | if isinstance(c, dict) and "value" in c: |
| | v = str(c["value"]).strip() |
| | if v: |
| | return v[:500] |
| |
|
| | return str(sample.get("protein", "")).strip()[:500] |
| |
|
| |
|
def stream_protein_batches(
    stream,
    bert_tokenizer,
    esm_tokenizer,
    batch_size: int,
) -> Iterable[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Yield (text_ids, text_mask, esm_ids, esm_mask) batches from a stream.

    Pairs with a caption or sequence shorter than 5 chars, or failing ESM
    tokenization, are skipped. A final partial batch is yielded at the end.
    """
    ids_buf: List[torch.Tensor] = []
    masks_buf: List[torch.Tensor] = []
    esm_ids_buf: List[torch.Tensor] = []
    esm_masks_buf: List[torch.Tensor] = []

    def emit():
        return (
            torch.stack(ids_buf),
            torch.stack(masks_buf),
            torch.stack(esm_ids_buf),
            torch.stack(esm_masks_buf),
        )

    for sample in stream:
        caption = extract_protein_caption(sample)
        seq = str(sample.get("amino_seq", sample.get("protein_sequence", ""))).strip()

        # Drop degenerate pairs before paying tokenizer cost.
        if len(caption) < 5 or len(seq) < 5:
            continue

        ids, mask = masked_text_tokenize(caption, bert_tokenizer)

        try:
            esm_encoded = esm_tokenizer(
                seq,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
        except Exception:
            continue  # untokenizable sequence: drop the sample

        ids_buf.append(ids)
        masks_buf.append(mask)
        esm_ids_buf.append(esm_encoded["input_ids"].squeeze(0))
        esm_masks_buf.append(esm_encoded["attention_mask"].squeeze(0))

        if len(ids_buf) >= batch_size:
            yield emit()
            ids_buf, masks_buf, esm_ids_buf, esm_masks_buf = [], [], [], []

    if ids_buf:
        yield emit()
| |
|
| |
|
@torch.no_grad()
def encode_streaming_batches(
    *,
    tag: str,
    expert_name: str,
    expert_hidden_shape: Tuple[int, int],
    batch_iter: Iterable,
    bert,
    expert_batch_forward: Callable[..., torch.Tensor],
    shard_size: int,
    row_keys: List[str],
    max_samples: Optional[int] = None,
) -> str:
    """Encode pre-batched streaming data into a cached HF dataset.

    Each tuple from *batch_iter* is (text_ids, text_mask, *expert_args): the
    first two tensors go through the shared *bert*, the remaining ones are
    splatted into *expert_batch_forward*. Rows (fp16 hidden states + bool
    mask) are written via ShardWriter. Returns the cache path; an existing
    cache short-circuits all work.
    """
    cache_path = os.path.join(CFG.cache_dir, tag)
    if os.path.exists(cache_path):
        print(f" Cache exists: {cache_path}")
        return cache_path

    features = Features({
        "text_hidden": Array2D(shape=(CFG.max_text_len, CFG.bert_hidden_dim), dtype="float16"),
        "text_mask": Sequence(Value("bool"), length=CFG.max_text_len),
        f"{expert_name}_hidden": Array2D(shape=expert_hidden_shape, dtype="float16"),
    })

    writer = ShardWriter(
        cache_dir=CFG.cache_dir,
        tag=tag,
        features=features,
        shard_size=shard_size,
        row_keys=row_keys,
    )

    t0 = time.time()
    n = 0  # total samples written so far

    for packed in batch_iter:
        # Positions 0/1 are always the BERT inputs.
        text_ids = packed[0].to(DEVICE, non_blocking=True)
        text_mask = packed[1].to(DEVICE, non_blocking=True)

        text_hidden = bert(
            input_ids=text_ids,
            attention_mask=text_mask,
        ).last_hidden_state.detach().to(dtype=torch.float16).cpu().numpy()

        text_mask_np = text_mask.bool().cpu().numpy()

        # Everything after position 1 feeds the expert encoder.
        expert_hidden = expert_batch_forward(*[p.to(DEVICE, non_blocking=True) for p in packed[2:]])
        expert_hidden = expert_hidden.detach().to(dtype=torch.float16).cpu().numpy()

        for i in range(text_hidden.shape[0]):
            writer.add_row({
                "text_hidden": text_hidden[i],
                "text_mask": text_mask_np[i].tolist(),
                f"{expert_name}_hidden": expert_hidden[i],
            })

        n += text_hidden.shape[0]
        batch_size = text_hidden.shape[0]
        # Log roughly once per 1000 samples (and always on the first batch).
        if n % 1000 < batch_size or n <= batch_size:
            rate = n / max(time.time() - t0, 1e-6)
            print(f" {n}" + (f"/{max_samples}" if max_samples else "") + f" ({rate:.0f}/s)")

        if max_samples is not None and n >= max_samples:
            break

    final_path = writer.finalize()
    print(f" Completed {n} samples in {time.time() - t0:.0f}s")
    return final_path
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | def encode_image_expert(bert, split: str, tag: str, max_samples: Optional[int] = None) -> str: |
| | from transformers import Dinov2Model, AutoImageProcessor |
| |
|
| | print(f"\n [IMAGE] Loading DINOv2-large + COCO-Caption ({split})...") |
| | dino = Dinov2Model.from_pretrained( |
| | "facebook/dinov2-large", |
| | torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, |
| | ).to(DEVICE).eval() |
| | proc = AutoImageProcessor.from_pretrained("facebook/dinov2-large") |
| | tok = get_bert_tokenizer() |
| |
|
| | hf_ds = load_dataset("lmms-lab/COCO-Caption", split=split) |
| | if max_samples is not None: |
| | hf_ds = hf_ds.select(range(min(max_samples, len(hf_ds)))) |
| | print(f" Dataset: {len(hf_ds)} samples") |
| |
|
| | torch_ds = ImageTextDataset(hf_ds, tok, proc) |
| | loader = make_loader(torch_ds, batch_size=CFG.batch_size, num_workers=CFG.num_workers) |
| |
|
| | def expert_forward(pixel_values): |
| | return dino(pixel_values=pixel_values).last_hidden_state |
| |
|
| | path = encode_map_dataset( |
| | tag=tag, |
| | loader=loader, |
| | bert=bert, |
| | expert_name="image", |
| | expert_hidden_shape=(257, 1024), |
| | expert_forward=expert_forward, |
| | shard_size=CFG.shard_size_default, |
| | max_samples=max_samples, |
| | ) |
| |
|
| | del dino, proc, hf_ds, torch_ds, loader |
| | cleanup_cuda() |
| | return path |
| |
|
| |
|
| | def encode_code_expert(bert, max_samples: Optional[int] = None) -> str: |
| | from transformers import RobertaModel, RobertaTokenizer |
| |
|
| | print("\n [CODE] Loading CodeBERT + CodeSearchNet python...") |
| | codebert = RobertaModel.from_pretrained( |
| | "microsoft/codebert-base", |
| | torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, |
| | ).to(DEVICE).eval() |
| | code_tok = RobertaTokenizer.from_pretrained("microsoft/codebert-base") |
| | tok = get_bert_tokenizer() |
| |
|
| | hf_ds = load_dataset("code-search-net/code_search_net", "python", split="train") |
| | if max_samples is not None: |
| | hf_ds = hf_ds.select(range(min(max_samples, len(hf_ds)))) |
| |
|
| | hf_ds = hf_ds.filter( |
| | lambda x: bool(x.get("func_documentation_string", "").strip()), |
| | num_proc=4, |
| | ) |
| | print(f" Dataset: {len(hf_ds)} samples (after filtering)") |
| |
|
| | torch_ds = CodeTextDataset(hf_ds, tok, code_tok) |
| | loader = make_loader(torch_ds, batch_size=CFG.batch_size, num_workers=CFG.num_workers) |
| |
|
| | def expert_forward(packed): |
| | code_ids = packed[:, 0].long() |
| | code_mask = packed[:, 1].long() |
| | return codebert(input_ids=code_ids, attention_mask=code_mask).last_hidden_state |
| |
|
| | path = encode_map_dataset( |
| | tag="code_csn", |
| | loader=loader, |
| | bert=bert, |
| | expert_name="code", |
| | expert_hidden_shape=(256, 768), |
| | expert_forward=expert_forward, |
| | shard_size=CFG.shard_size_default, |
| | max_samples=max_samples, |
| | ) |
| |
|
| | del codebert, code_tok, hf_ds, torch_ds, loader |
| | cleanup_cuda() |
| | return path |
| |
|
| |
|
| | def encode_audio_expert(bert, max_samples: Optional[int] = None) -> str: |
| | from transformers import WhisperModel, WhisperFeatureExtractor |
| |
|
| | print("\n [AUDIO] Loading Whisper-large-v3 + LibriSpeech ASR (streaming)...") |
| | whisper = WhisperModel.from_pretrained( |
| | "openai/whisper-large-v3", |
| | torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, |
| | ).to(DEVICE).eval() |
| | proc = WhisperFeatureExtractor.from_pretrained("openai/whisper-large-v3") |
| | tok = get_bert_tokenizer() |
| |
|
| | max_n = max_samples or CFG.max_audio_samples |
| | audio_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 |
| |
|
| | |
| | probe_stream = load_dataset("openslr/librispeech_asr", "clean", split="train.100", streaming=True) |
| | probe_stream = probe_stream.cast_column("audio", Audio(sampling_rate=16000)) |
| | first = next(iter(probe_stream)) |
| | arr, sr = decode_audio_obj(first["audio"]) |
| |
|
| | mel = proc(arr, sampling_rate=sr, return_tensors="pt").input_features |
| | mel = mel.to(device=DEVICE, dtype=audio_dtype) |
| |
|
| | with torch.no_grad(): |
| | probe_hidden = whisper.encoder(mel).last_hidden_state |
| |
|
| | seq_len, hidden_dim = probe_hidden.shape[1], probe_hidden.shape[2] |
| | print(f" Whisper encoder output: ({seq_len}, {hidden_dim})") |
| | del mel, probe_hidden |
| |
|
| | stream = load_dataset("openslr/librispeech_asr", "clean", split="train.100", streaming=True) |
| | stream = stream.cast_column("audio", Audio(sampling_rate=16000)) |
| |
|
| | batch_iter = stream_librispeech_batches( |
| | stream=stream, |
| | bert_tokenizer=tok, |
| | whisper_processor=proc, |
| | batch_size=16, |
| | ) |
| |
|
| | def expert_batch_forward(mels: torch.Tensor) -> torch.Tensor: |
| | mels = mels.to(dtype=audio_dtype) |
| | return whisper.encoder(mels).last_hidden_state |
| |
|
| | path = encode_streaming_batches( |
| | tag="audio_librispeech", |
| | expert_name="audio", |
| | expert_hidden_shape=(seq_len, hidden_dim), |
| | batch_iter=batch_iter, |
| | bert=bert, |
| | expert_batch_forward=expert_batch_forward, |
| | shard_size=256, |
| | row_keys=["text_hidden", "text_mask", "audio_hidden"], |
| | max_samples=max_n, |
| | ) |
| |
|
| | del whisper, proc |
| | cleanup_cuda() |
| | return path |
| |
|
| |
|
| | def encode_protein_expert(bert, max_samples: Optional[int] = None) -> str: |
| | from transformers import EsmModel, EsmTokenizer |
| |
|
| | print("\n [PROTEIN] Loading ESM-2-650M + Protein2Text-QA (streaming)...") |
| | esm = EsmModel.from_pretrained( |
| | "facebook/esm2_t33_650M_UR50D", |
| | torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, |
| | ).to(DEVICE).eval() |
| | esm_tok = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D") |
| | tok = get_bert_tokenizer() |
| |
|
| | max_n = max_samples or CFG.max_protein_samples |
| | stream = load_dataset("tumorailab/Protein2Text-QA", split="test", streaming=True) |
| |
|
| | batch_iter = stream_protein_batches( |
| | stream=stream, |
| | bert_tokenizer=tok, |
| | esm_tokenizer=esm_tok, |
| | batch_size=32, |
| | ) |
| |
|
| | def expert_batch_forward(esm_ids: torch.Tensor, esm_mask: torch.Tensor) -> torch.Tensor: |
| | return esm(input_ids=esm_ids.long(), attention_mask=esm_mask.long()).last_hidden_state |
| |
|
| | path = encode_streaming_batches( |
| | tag="protein_p2t", |
| | expert_name="protein", |
| | expert_hidden_shape=(512, 1280), |
| | batch_iter=batch_iter, |
| | bert=bert, |
| | expert_batch_forward=expert_batch_forward, |
| | shard_size=512, |
| | row_keys=["text_hidden", "text_mask", "protein_hidden"], |
| | max_samples=max_n, |
| | ) |
| |
|
| | del esm, esm_tok |
| | cleanup_cuda() |
| | return path |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def main():
    """Orchestrate the Stage-1 precompute: encode each expert modality once,
    caching results under CFG.cache_dir and skipping anything already cached.

    Audio/protein/code failures are caught and reported; an image failure
    aborts the run (no try/except around that step).
    """
    ensure_dir(CFG.cache_dir)

    print("=" * 70)
    print("STAGE 1: MULTI-EXPERT PRECOMPUTE")
    print("=" * 70)

    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"Cache: {CFG.cache_dir}")

    # Only load the heavy shared BERT when at least one cache is missing.
    required_tags = [
        "image_coco",
        "image_coco_test",
        "audio_librispeech",
        "protein_p2t",
        "code_csn",
    ]
    missing = [t for t in required_tags if not os.path.exists(os.path.join(CFG.cache_dir, t))]

    if not missing:
        print("\nAll caches exist. Nothing to encode.")
        bert = None
    else:
        print(f"\nMissing caches: {missing}")
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()
        bert = load_shared_bert()

    paths: Dict[str, Optional[str]] = {}

    # [1/4] IMAGE — COCO val split.
    print(f"\n{'─' * 50}")
    print("[1/4] IMAGE — DINOv2 + COCO-Caption")
    if os.path.exists(os.path.join(CFG.cache_dir, "image_coco")):
        print(" [IMAGE] Cache exists, skipping.")
        paths["image"] = os.path.join(CFG.cache_dir, "image_coco")
    else:
        paths["image"] = encode_image_expert(bert, split="val", tag="image_coco")
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()

    # IMAGE test split (same expert, separate cache tag).
    if os.path.exists(os.path.join(CFG.cache_dir, "image_coco_test")):
        print("\n [IMAGE-TEST] Cache exists, skipping.")
        paths["image_test"] = os.path.join(CFG.cache_dir, "image_coco_test")
    else:
        print("\n [IMAGE-TEST] COCO test split...")
        paths["image_test"] = encode_image_expert(bert, split="test", tag="image_coco_test")
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()

    # [2/4] AUDIO — failure is tolerated; path recorded as None.
    print(f"\n{'─' * 50}")
    print("[2/4] AUDIO — Whisper + LibriSpeech")
    if os.path.exists(os.path.join(CFG.cache_dir, "audio_librispeech")):
        print(" [AUDIO] Cache exists, skipping.")
        paths["audio"] = os.path.join(CFG.cache_dir, "audio_librispeech")
    else:
        try:
            paths["audio"] = encode_audio_expert(bert, max_samples=CFG.max_audio_samples)
        except Exception as e:
            print(f" AUDIO failed: {e}")
            paths["audio"] = None
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()

    # [3/4] PROTEIN — failure is tolerated; path recorded as None.
    print(f"\n{'─' * 50}")
    print("[3/4] PROTEIN — ESM-2 + Protein2Text-QA")
    if os.path.exists(os.path.join(CFG.cache_dir, "protein_p2t")):
        print(" [PROTEIN] Cache exists, skipping.")
        paths["protein"] = os.path.join(CFG.cache_dir, "protein_p2t")
    else:
        try:
            paths["protein"] = encode_protein_expert(bert, max_samples=CFG.max_protein_samples)
        except Exception as e:
            print(f" PROTEIN failed: {e}")
            paths["protein"] = None
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()

    # [4/4] CODE — failure is tolerated; path recorded as None.
    print(f"\n{'─' * 50}")
    print("[4/4] CODE — CodeBERT + CodeSearchNet Python")
    if os.path.exists(os.path.join(CFG.cache_dir, "code_csn")):
        print(" [CODE] Cache exists, skipping.")
        paths["code"] = os.path.join(CFG.cache_dir, "code_csn")
    else:
        try:
            paths["code"] = encode_code_expert(bert, max_samples=CFG.max_code_samples)
        except Exception as e:
            print(f" CODE failed: {e}")
            paths["code"] = None
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()

    # Free the shared BERT now that all encoding is done.
    if bert is not None:
        del bert
    cleanup_cuda()

    # A pre-existing flickr30k cache (produced elsewhere) is reported if found.
    flickr_path = os.path.join(CFG.cache_dir, "flickr30k")
    if os.path.exists(flickr_path):
        paths["flickr"] = flickr_path

    print(f"\n{'=' * 70}")
    print("CACHE SUMMARY")
    print(f"{'=' * 70}")
    for name, path in paths.items():
        if path and os.path.exists(path):
            ds = load_from_disk(path)
            print(f" {name:15s}: {len(ds):6d} pairs [{path}]")

    print("\nReady for Stage 2 multi-expert training.")
    print("Done.")
| |
|
| |
|
# Script entry point.
if __name__ == "__main__":
    main()