"""
Bertint V8 Dataset — Per-Residue Protein Embeddings for Cross-Attention

Like V7, but keeps per-residue ESM-C embeddings [Lp, D] instead of
mean-pooling them to [D]. This enables token-level cross-attention between
glycan tokens and protein residues.

Changes from V7:
  - protein_emb: [Lp, 960] per-residue (not [960] mean-pooled)
  - collate_fn: pads protein sequences to the max length in the batch
  - collate_fn: returns protein_mask for cross-attention padding
"""
|
|
| import json |
| import logging |
| import os |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import h5py |
| import numpy as np |
| import pandas as pd |
| import torch |
| from torch.utils.data import Dataset |
| from torch.nn.utils.rnn import pad_sequence |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| |
| |
| |
|
|
def load_bpe_tokenizer(vocab_path: str):
    """
    Load the Bertose BPE tokenizer directly from source.

    Bypasses downstream_tasks package imports: the
    ``bert_training_v4/downstream_tasks/utils`` directory is put on
    ``sys.path`` and ``WURCSBPETokenizer`` is imported from the
    ``wurcs_bpe_tokenizer`` module found there.

    The utils directory is looked up, in order, under:
      1. ``$BERTOSE_ROOT`` or ``$BERTOSE_REPO_ROOT`` (when set),
      2. every parent directory of this file,
      3. the current working directory and every parent of it,
      4. a hard-coded legacy cluster path (last resort).

    Args:
        vocab_path: Path to BPE vocabulary JSON file.

    Returns:
        WURCSBPETokenizer instance.

    Raises:
        ModuleNotFoundError: If ``wurcs_bpe_tokenizer`` cannot be imported
            from any of the searched locations (or from ``sys.path``).
    """
    import sys

    relative_utils = Path("bert_training_v4") / "downstream_tasks" / "utils"
    legacy_utils = Path(
        "/work/ratul1/supantha/glycan-SD-VS/bert_training_v3/"
        "v3.1_cluster_training/bert_training_v4/downstream_tasks/utils"
    )

    # Build the ordered list of repository roots to probe.
    env_root = os.environ.get("BERTOSE_ROOT") or os.environ.get("BERTOSE_REPO_ROOT")
    candidate_roots = []
    if env_root:
        candidate_roots.append(Path(env_root).expanduser().resolve())
    candidate_roots.extend(Path(__file__).resolve().parents)
    working_dir = Path.cwd()
    candidate_roots.append(working_dir)
    # Launching from a sub-directory of the repo should work too.
    candidate_roots.extend(working_dir.parents)

    # Probe each distinct candidate once, remembering what was tried so a
    # failure can report it.
    searched = []
    seen = set()
    utils_dir = None
    for root in candidate_roots:
        candidate = root / relative_utils
        if candidate in seen:
            continue
        seen.add(candidate)
        searched.append(candidate)
        if candidate.is_dir():
            utils_dir = candidate
            break

    if utils_dir is None:
        # Last resort: the original cluster checkout location.
        searched.append(legacy_utils)
        if legacy_utils.is_dir():
            utils_dir = legacy_utils

    if utils_dir is not None:
        utils_dir = str(utils_dir)
        if utils_dir not in sys.path:
            sys.path.insert(0, utils_dir)

    try:
        from wurcs_bpe_tokenizer import WURCSBPETokenizer
    except ModuleNotFoundError as exc:
        # Only rewrite the error when the tokenizer module itself is
        # missing; a missing dependency *of* that module is reported as-is.
        if exc.name != "wurcs_bpe_tokenizer":
            raise
        tried = "\n  ".join(str(path) for path in searched)
        raise ModuleNotFoundError(
            "Could not import 'wurcs_bpe_tokenizer'. Set BERTOSE_ROOT (or "
            "BERTOSE_REPO_ROOT) to the repository root containing "
            f"'{relative_utils.as_posix()}'. Searched:\n  {tried}",
            name=exc.name,
        ) from exc
    return WURCSBPETokenizer(vocab_path)
|
|
|
|
| |
| |
| |
|
|
class BertintV8Dataset(Dataset):
    """
    Dataset for glycan-protein interaction with cross-attention support.

    Each item provides:
        - BPE-tokenized glycan sequences for live Bertose forward pass
        - Per-residue ESM-C protein embeddings [Lp, D] (NOT mean-pooled)
        - The protein length, from which the collate step builds the
          protein-side padding mask
        - The regression target and optional concentration features

    Args:
        csv_path: Path to binding data CSV.
        split_path: Path to glycan-cold splits JSON.
        split: One of 'train', 'val', 'test'.
        protein_emb_path: Path to ESM-C embeddings HDF5.
        vocab_path: Path to BPE vocabulary JSON.
        max_glycan_length: Maximum glycan sequence length.
        max_protein_length: Maximum protein residues (truncate longer).
        target_col: Column name for regression target.

    Raises:
        KeyError: If ``split`` is not present in the splits JSON.
    """

    # Keys read from the tokenizer output and cached as long tensors.
    _GLYCAN_FIELDS = (
        "token_ids",
        "attention_mask",
        "branch_depths",
        "linkage_types",
    )

    def __init__(
        self,
        csv_path: str,
        split_path: str,
        split: str,
        protein_emb_path: str,
        vocab_path: str,
        max_glycan_length: int = 256,
        max_protein_length: int = 1024,
        target_col: str = "target_rank",
    ):
        logger.info(f"Loading {split} dataset from {csv_path}")

        # --- Split definition -------------------------------------------
        with open(split_path, encoding="utf-8") as f:
            splits_data = json.load(f)
        # The splits file may nest the actual split lists under "glycan_cold".
        if "glycan_cold" in splits_data:
            splits_data = splits_data["glycan_cold"]
        if split not in splits_data:
            raise KeyError(
                f"Split {split!r} not found in {split_path}; "
                f"available: {sorted(splits_data)}"
            )
        split_glycans = set(splits_data[split])
        logger.info(f" {split}: {len(split_glycans)} glycans in split")

        # --- Binding records --------------------------------------------
        df = pd.read_csv(csv_path)
        df = df[df["glycan_wurcs"].isin(split_glycans)].copy()
        df = df.dropna(subset=[target_col])
        logger.info(f" {len(df):,} records after split + target filter")

        self.records = df.reset_index(drop=True)
        self.target_col = target_col
        self.max_protein_length = max_protein_length

        # --- Glycan tokenization (once per unique glycan) -----------------
        logger.info(f" Loading BPE tokenizer from {vocab_path}")
        self.tokenizer = load_bpe_tokenizer(vocab_path)
        self.max_glycan_length = max_glycan_length

        unique_wurcs = df["glycan_wurcs"].unique()
        logger.info(f" Pre-tokenizing {len(unique_wurcs)} unique glycans...")
        self.tokenized_cache: Dict[str, Dict[str, torch.Tensor]] = {}
        skipped = 0
        for wurcs in unique_wurcs:
            try:
                tok = self.tokenizer.tokenize(
                    wurcs, max_length=max_glycan_length
                )
                self.tokenized_cache[wurcs] = {
                    name: torch.tensor(tok[name], dtype=torch.long)
                    for name in self._GLYCAN_FIELDS
                }
            except (KeyError, ValueError) as exc:
                skipped += 1
                # Only the first few failures are logged individually.
                if skipped <= 5:
                    logger.warning(
                        f" Tokenization failed for WURCS: "
                        f"{wurcs[:60]}... ({exc})"
                    )
        if skipped > 0:
            logger.warning(
                f" Skipped {skipped} glycans with tokenization errors"
            )
            self.records = self.records[
                self.records["glycan_wurcs"].isin(self.tokenized_cache)
            ].reset_index(drop=True)
            logger.info(
                f" {len(self.records):,} records after removing "
                f"un-tokenizable"
            )

        # --- Per-residue protein embeddings ------------------------------
        logger.info(" Loading per-residue protein embeddings...")
        self.protein_embs: Dict[str, torch.Tensor] = {}
        with h5py.File(protein_emb_path, "r") as f:
            for key in f.keys():
                # HDF5 keys use "|" where the CSV protein ids use "/".
                protein_id = key.replace("|", "/")
                self.protein_embs[protein_id] = self._read_embedding(
                    f[key], max_protein_length
                )
        logger.info(f" {len(self.protein_embs)} proteins loaded")

        lengths = [v.shape[0] for v in self.protein_embs.values()]
        if lengths:
            logger.info(
                f" Protein lengths: min={min(lengths)}, "
                f"max={max(lengths)}, mean={np.mean(lengths):.0f}"
            )
        else:
            # min()/max() of an empty list would raise; report instead.
            logger.warning(
                f" No protein embeddings found in {protein_emb_path}"
            )

        # --- Drop records without a protein embedding --------------------
        available = set(self.protein_embs.keys())
        has_protein = self.records["protein_id"].isin(available)
        if not has_protein.all():
            missing = (~has_protein).sum()
            logger.warning(
                f" {missing} records missing protein embeddings"
            )
            self.records = self.records[has_protein].reset_index(drop=True)

        logger.info(f" Final {split} dataset: {len(self.records):,} records")

    @staticmethod
    def _read_embedding(dataset, max_protein_length: int) -> torch.Tensor:
        """
        Read one protein embedding as a float tensor of shape [Lp, D].

        Args:
            dataset: Array-like exposing ``ndim`` and slice indexing that
                returns a NumPy array (an ``h5py`` dataset or an ndarray).
            max_protein_length: Maximum number of residues (rows) to keep.

        Returns:
            Float tensor with at most ``max_protein_length`` rows. A rank-1
            input is treated as a single row and becomes shape [1, D].
        """
        if dataset.ndim == 1:
            # A rank-1 vector has no residue axis: read it whole so the
            # feature axis is never truncated, then add a length-1 axis.
            emb = torch.from_numpy(dataset[:]).float().unsqueeze(0)
            if emb.shape[0] > max_protein_length:
                emb = emb[:max_protein_length]
            return emb
        # Slice at read time so residues past the limit are neither read
        # nor kept alive by a view onto the full array.
        return torch.from_numpy(dataset[:max_protein_length]).float()

    def __len__(self) -> int:
        """Return the number of binding records in this split."""
        return len(self.records)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Return the sample at position ``idx``.

        Returns:
            Dict with the cached glycan tensors (``token_ids``,
            ``attention_mask``, ``branch_depths``, ``linkage_types``),
            ``protein_emb`` ([Lp, D]), ``protein_length`` (a plain int,
            not a tensor), and scalar float tensors ``target``,
            ``has_conc`` and ``log_conc``. The two concentration features
            default to 0 when the column is absent from the CSV.
        """
        row = self.records.iloc[idx]
        wurcs = row["glycan_wurcs"]
        protein_id = row["protein_id"]

        # Pre-tokenized glycan tensors, shared between records.
        cached = self.tokenized_cache[wurcs]

        # Per-residue protein embedding, already truncated at load time.
        protein_emb = self.protein_embs[protein_id]

        target = torch.tensor(row[self.target_col], dtype=torch.float)

        # Optional concentration features.
        has_conc = torch.tensor(row.get("has_conc", 0), dtype=torch.float)
        log_conc = torch.tensor(row.get("log_conc", 0.0), dtype=torch.float)

        return {
            "token_ids": cached["token_ids"],
            "attention_mask": cached["attention_mask"],
            "branch_depths": cached["branch_depths"],
            "linkage_types": cached["linkage_types"],
            "protein_emb": protein_emb,
            "protein_length": protein_emb.shape[0],
            "target": target,
            "has_conc": has_conc,
            "log_conc": log_conc,
        }
|
|
|
|
def collate_fn(
    batch: List[Dict[str, torch.Tensor]],
) -> Dict[str, torch.Tensor]:
    """
    Merge dataset items into one batch, padding the protein side.

    Glycan tensors are stacked as-is, so every item must already carry
    glycan tensors of identical length (the BPE tokenizer is expected to
    pad to max_glycan_length). Protein embeddings vary in length and are
    zero-padded to the longest protein in the batch; ``protein_mask``
    marks real residues with 1.0 and padding with 0.0.

    Args:
        batch: Items as produced by the dataset's ``__getitem__``.

    Returns:
        Dict of batched tensors: ``token_ids``, ``attention_mask`` (cast
        to float), ``branch_depths``, ``linkage_types``, ``protein_emb``
        ([B, Lmax, D]), ``protein_mask`` ([B, Lmax]), ``target``,
        ``has_conc`` and ``log_conc``.
    """

    def _stack(field: str) -> torch.Tensor:
        # Stack one per-item tensor field along a new leading batch axis.
        return torch.stack([sample[field] for sample in batch])

    # Glycan side: fixed-length tensors, stacked directly.
    glycan_ids = _stack("token_ids")
    glycan_mask = _stack("attention_mask").float()
    depths = _stack("branch_depths")
    linkages = _stack("linkage_types")

    # Protein side: zero-pad each [Lp, D] embedding to the batch maximum.
    padded_proteins = pad_sequence(
        [sample["protein_emb"] for sample in batch], batch_first=True
    )

    # 1.0 over each protein's real residues, 0.0 over its padding.
    residue_mask = torch.zeros(len(batch), padded_proteins.shape[1])
    for row_idx, sample in enumerate(batch):
        n_residues = sample["protein_length"]
        residue_mask[row_idx, :n_residues] = 1.0

    # Scalar-per-item fields.
    stacked_targets = _stack("target")
    stacked_has_conc = _stack("has_conc")
    stacked_log_conc = _stack("log_conc")

    return {
        "token_ids": glycan_ids,
        "attention_mask": glycan_mask,
        "branch_depths": depths,
        "linkage_types": linkages,
        "protein_emb": padded_proteins,
        "protein_mask": residue_mask,
        "target": stacked_targets,
        "has_conc": stacked_has_conc,
        "log_conc": stacked_log_conc,
    }
|
|