| """ |
| Multimodal Glycan Dataset |
| |
| Combines sequence (WURCS), MS, and 3D structure data for multimodal BERT training. |
| Handles optional modalities (MS and 3D structure). |
| """ |
|
|
| import torch |
| from torch.utils.data import Dataset |
| import pickle |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
| import numpy as np |
|
|
|
|
class MultimodalGlycanDataset(Dataset):
    """
    Dataset for multimodal glycan BERT training.

    Combines:
    - Sequence tokens (WURCS atomic tokenization)
    - MS tokens (mass spectrometry peaks, RT, intensity)
    - 3D structure tokens (VQ-VAE discrete tokens, 4 per residue)

    Each modality can be enabled/disabled via flags. Every sample returned by
    ``__getitem__`` consists of fixed-size tensors (padded/truncated to the
    ``max_*`` limits given at construction), so samples can be stacked directly.
    """

    def __init__(
        self,
        sequences_path: str,
        ms_tokens_path: str,
        structure_data_path: Optional[str] = None,
        max_seq_length: int = 512,
        max_ms_length: int = 150,
        max_mono_length: int = 50,
        max_struct_tokens: int = 200,
        max_atoms: int = 300,
        include_ms: bool = True,
        include_3d: bool = True,
    ):
        """
        Initialize multimodal dataset.

        Args:
            sequences_path: Path to sequences.pkl. Either a dict keyed by WURCS
                string whose values are per-sample dicts (must contain
                ``token_ids``), or any other object that is used as-is as the
                sample container.
            ms_tokens_path: Path to ms_tokens.pkl (MS token IDs per WURCS).
                Only opened when ``include_ms`` is True.
            structure_data_path: Path to training_dataset.pkl (structure samples,
                each carrying a ``wurcs`` key). Only opened when ``include_3d``
                is True and the file exists.
            max_seq_length: Maximum sequence length (truncate/pad)
            max_ms_length: Maximum MS token length (truncate/pad)
            max_mono_length: Maximum number of monosaccharides (truncate/pad)
            max_struct_tokens: Maximum structural tokens (truncate/pad)
            max_atoms: Maximum number of atoms; stored on the instance but not
                used by any method of this class.
            include_ms: Whether to include MS modality
            include_3d: Whether to include 3D structure modality
        """
        self.max_seq_length = max_seq_length
        self.max_ms_length = max_ms_length
        self.max_mono_length = max_mono_length
        self.max_struct_tokens = max_struct_tokens
        self.max_atoms = max_atoms
        self.include_ms = include_ms
        self.include_3d = include_3d

        # NOTE(security): pickle.load can execute arbitrary code. Only point the
        # three *_path arguments at files from a trusted source.
        print(f"Loading sequences from {sequences_path}...")
        with open(sequences_path, 'rb') as f:
            sequences_raw = pickle.load(f)

        if isinstance(sequences_raw, dict):
            # Dict format: {wurcs: sample_dict}. Flatten to a list and make sure
            # each sample remembers its own WURCS key.
            self.sequences = []
            for wurcs, seq_data in sequences_raw.items():
                if not isinstance(seq_data, dict):
                    print(f"Warning: Skipping invalid entry for WURCS: {str(wurcs)[:50]}...")
                    continue
                if 'token_ids' not in seq_data:
                    print(f"Warning: Skipping entry without token_ids for WURCS: {str(wurcs)[:50]}...")
                    continue
                seq_data.setdefault('wurcs', wurcs)
                self.sequences.append(seq_data)
        else:
            self.sequences = sequences_raw

        print(f" Loaded {len(self.sequences)} sequences")

        self.ms_tokens = {}
        if self.include_ms:
            print(f"Loading MS tokens from {ms_tokens_path}...")
            with open(ms_tokens_path, 'rb') as f:
                self.ms_tokens = pickle.load(f)
            print(f" Loaded {len(self.ms_tokens)} MS token sets")

        self.structure_data = {}
        if self.include_3d and structure_data_path:
            if Path(structure_data_path).exists():
                print(f"Loading 3D structure data from {structure_data_path}...")
                with open(structure_data_path, 'rb') as f:
                    struct_pkl = pickle.load(f)

                # Either {'full_multimodal': [samples...]} or a plain iterable
                # of samples; index both by WURCS string.
                if isinstance(struct_pkl, dict) and 'full_multimodal' in struct_pkl:
                    samples = struct_pkl['full_multimodal']
                else:
                    samples = struct_pkl
                self.structure_data = {s['wurcs']: s for s in samples}

                print(f" Loaded {len(self.structure_data)} structure samples")
            else:
                print(f" Warning: Structure data file not found at {structure_data_path}")
                print(f" Continuing without 3D structure modality...")

        self._compute_stats()

    def _compute_stats(self):
        """Compute and print dataset statistics; results are stored in ``self.stats``."""
        ms_available = 0
        struct_available = 0

        for s in self.sequences:
            wurcs = s.get('wurcs', '')
            if wurcs in self.ms_tokens:
                ms_available += 1
            if wurcs in self.structure_data:
                struct_available += 1

        total = len(self.sequences)
        self.stats = {
            'total': total,
            'with_ms_available': ms_available,
            'with_3d_available': struct_available,
            'with_ms_tokens': len(self.ms_tokens),
            'with_structure_tokens': len(self.structure_data),
        }

        # Guard the percentages so an empty dataset reports 0% instead of
        # raising ZeroDivisionError.
        ms_pct = 100 * ms_available / total if total else 0.0
        struct_pct = 100 * struct_available / total if total else 0.0

        print(f"\nDataset Statistics:")
        print(f" Total sequences: {total:,}")
        print(f" With MS data: {ms_available:,} ({ms_pct:.2f}%)")
        print(f" With 3D data: {struct_available:,} ({struct_pct:.2f}%)")
        print(f" Include MS: {self.include_ms}")
        print(f" Include 3D: {self.include_3d}")
        print()

    def __len__(self) -> int:
        """Return the number of sequence samples."""
        return len(self.sequences)

    @staticmethod
    def _fit_length(values, length: int, pad_value: int) -> List[int]:
        """Return ``values`` as a list of exactly ``length`` items (truncate, then right-pad)."""
        fitted = list(values)[:length]
        fitted.extend([pad_value] * (length - len(fitted)))
        return fitted

    def _build_distance_labels(self, dist_matrix) -> torch.Tensor:
        """Copy a (possibly ragged) distance matrix into a square float tensor padded with -1."""
        size = self.max_seq_length
        labels = torch.full((size, size), -1.0, dtype=torch.float)
        if dist_matrix is None:
            return labels

        for i, row in enumerate(dist_matrix[:size]):
            width = min(len(row), size)
            if width:
                labels[i, :width] = torch.as_tensor(row[:width], dtype=torch.float)
        return labels

    def _encode_ms(self, wurcs) -> Tuple[bool, List[int], List[int], List[int]]:
        """Return ``(has_ms, token_ids, residue_ids, attention_mask)`` for the MS modality."""
        length = self.max_ms_length
        absent = (False, [0] * length, [-1] * length, [0] * length)

        if not (self.include_ms and wurcs in self.ms_tokens):
            return absent

        ms_data = self.ms_tokens[wurcs]
        if isinstance(ms_data, dict) and 'ms_token_ids' in ms_data:
            tokens = ms_data['ms_token_ids']
        elif isinstance(ms_data, list):
            tokens = ms_data
        else:
            # Strings and any other payload type carry no usable token IDs.
            return absent

        # Only a non-empty list of non-string IDs counts as real MS data.
        if not isinstance(tokens, list) or not tokens or isinstance(tokens[0], str):
            return absent

        tokens = tokens[:length]
        n_real = len(tokens)
        n_pad = length - n_real
        return (
            True,
            tokens + [0] * n_pad,
            # -2 marks real MS tokens, -1 marks padding positions.
            [-2] * n_real + [-1] * n_pad,
            [1] * n_real + [0] * n_pad,
        )

    def _encode_monosaccharides(self, seq_data) -> Tuple[List[int], List[int]]:
        """Return validated, fixed-length ``(mono_indices, mono_residue_ids)``."""
        raw_indices = seq_data.get('monosaccharide_indices', [])
        raw_residues = seq_data.get('monosaccharide_residue_ids', [])
        if not isinstance(raw_indices, list):
            raw_indices = []
        if not isinstance(raw_residues, (list, tuple, np.ndarray)):
            raw_residues = []

        indices: List[int] = []
        residues: List[int] = []
        # Residue IDs are emitted only alongside an accepted index, so the two
        # lists always stay aligned (also when raw_indices is empty).
        for pos, raw_idx in enumerate(raw_indices):
            if isinstance(raw_idx, (int, np.integer)):
                index = int(raw_idx)
                residue = -1
                # Integer indices require an integer residue ID.
                if pos < len(raw_residues) and isinstance(raw_residues[pos], (int, np.integer)):
                    residue = int(raw_residues[pos])
            elif isinstance(raw_idx, str):
                try:
                    index = int(raw_idx)
                except ValueError:
                    continue
                residue = -1
                # String indices accept anything int() can convert as residue ID.
                if pos < len(raw_residues):
                    try:
                        residue = int(raw_residues[pos])
                    except (ValueError, TypeError):
                        residue = -1
            else:
                continue
            indices.append(index)
            residues.append(residue)

        return (
            self._fit_length(indices, self.max_mono_length, 0),
            self._fit_length(residues, self.max_mono_length, -1),
        )

    def _encode_structure(self, wurcs) -> Tuple[bool, List[int], List[int], List[int]]:
        """Return ``(has_3d, token_ids, residue_ids, attention_mask)`` for the 3D modality."""
        length = self.max_struct_tokens
        if not (self.include_3d and wurcs in self.structure_data):
            return False, [0] * length, [-1] * length, [0] * length

        struct_sample = self.structure_data[wurcs]

        # The stored mapping goes WURCS residue -> GraphML index; invert it so
        # each GraphML position can be labelled with its WURCS residue ID.
        wurcs_to_graphml = struct_sample.get('wurcs_to_graphml_mapping', {})
        graphml_to_wurcs = {v: k for k, v in wurcs_to_graphml.items()}

        token_ids: List[int] = []
        residue_ids: List[int] = []
        for graphml_idx, residue_tokens in enumerate(struct_sample['structural_tokens_per_residue']):
            token_ids.extend(residue_tokens)
            wurcs_res_id = graphml_to_wurcs.get(graphml_idx, -1)
            residue_ids.extend([wurcs_res_id] * len(residue_tokens))

        token_ids = token_ids[:length]
        residue_ids = residue_ids[:length]
        n_real = len(token_ids)
        n_pad = length - n_real
        # NOTE(review): has_3d stays True even when the sample yields zero
        # structural tokens (all-zero mask), unlike the MS branch - confirm
        # this is what the model expects.
        return (
            True,
            token_ids + [0] * n_pad,
            residue_ids + [-1] * n_pad,
            [1] * n_real + [0] * n_pad,
        )

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a single multimodal sample.

        Returns:
            Dictionary containing:
            - seq_token_ids: Sequence token IDs (padded with 0 / truncated)
            - seq_attention_mask: 1 for real sequence tokens, 0 for padding
            - seq_residue_ids: Residue IDs per sequence token (-1 for padding/missing)
            - seq_branch_depths: Branch depth per sequence token (0 for padding/missing)
            - seq_linkage_types: Linkage type per sequence token (0 for padding/missing)
            - dist_labels: (max_seq_length, max_seq_length) float matrix, -1 where undefined
            - ms_token_ids: MS token IDs (padded/truncated, all 0 if no MS)
            - ms_attention_mask: MS attention mask
            - ms_residue_ids: -2 for real MS tokens, -1 for padding
            - struct_token_ids: Structural token IDs (padded/truncated, all 0 if no 3D)
            - struct_attention_mask: Structural attention mask
            - struct_residue_ids: Residue ID per structural token (-1 for padding/unmapped)
            - mono_indices: Monosaccharide indices (padded with 0 / truncated)
            - mono_residue_ids: Residue IDs for each monosaccharide (-1 for padding/missing)
            - has_ms: Whether this sample has usable MS data
            - has_3d: Whether this sample has a 3D structure entry
            - has_residue_error: Value of the sample's ``has_residue_error`` flag (False if absent)
        """
        seq_data = self.sequences[idx]
        # Same lookup as _compute_stats: a sample without a WURCS key simply
        # matches no MS/3D entry instead of raising KeyError.
        wurcs = seq_data.get('wurcs', '')

        max_len = self.max_seq_length
        token_ids = list(seq_data['token_ids'])[:max_len]
        n_tokens = len(token_ids)

        def per_token_feature(key: str, fill: int) -> List[int]:
            # Align the feature with the (truncated) tokens first so a feature
            # list of a different length can never yield a wrong-size tensor.
            values = seq_data.get(key)
            if values is None:
                values = []
            return self._fit_length(list(values)[:n_tokens], max_len, fill)

        seq_token_ids = self._fit_length(token_ids, max_len, 0)
        seq_attention_mask = [1] * n_tokens + [0] * (max_len - n_tokens)
        seq_residue_ids = per_token_feature('residue_ids', -1)
        seq_branch_depths = per_token_feature('branch_depths', 0)
        seq_linkage_types = per_token_feature('linkage_types', 0)

        dist_labels = self._build_distance_labels(seq_data.get('distance_matrix', None))

        has_ms, ms_token_ids, ms_residue_ids, ms_attention_mask = self._encode_ms(wurcs)
        mono_indices, mono_residue_ids = self._encode_monosaccharides(seq_data)
        has_3d, struct_token_ids, struct_residue_ids, struct_attention_mask = self._encode_structure(wurcs)

        has_residue_error = seq_data.get('has_residue_error', False)

        result = {
            'seq_token_ids': torch.tensor(seq_token_ids, dtype=torch.long),
            'seq_attention_mask': torch.tensor(seq_attention_mask, dtype=torch.long),
            'seq_residue_ids': torch.tensor(seq_residue_ids, dtype=torch.long),
            'seq_branch_depths': torch.tensor(seq_branch_depths, dtype=torch.long),
            'seq_linkage_types': torch.tensor(seq_linkage_types, dtype=torch.long),
            'dist_labels': dist_labels,
            'ms_token_ids': torch.tensor(ms_token_ids, dtype=torch.long),
            'ms_attention_mask': torch.tensor(ms_attention_mask, dtype=torch.long),
            'ms_residue_ids': torch.tensor(ms_residue_ids, dtype=torch.long),
            'struct_token_ids': torch.tensor(struct_token_ids, dtype=torch.long),
            'struct_attention_mask': torch.tensor(struct_attention_mask, dtype=torch.long),
            'struct_residue_ids': torch.tensor(struct_residue_ids, dtype=torch.long),
            'mono_indices': torch.tensor(mono_indices, dtype=torch.long),
            'mono_residue_ids': torch.tensor(mono_residue_ids, dtype=torch.long),
            'has_ms': torch.tensor(has_ms, dtype=torch.bool),
            'has_3d': torch.tensor(has_3d, dtype=torch.bool),
            'has_residue_error': torch.tensor(has_residue_error, dtype=torch.bool),
        }

        return result
|
|
|
|
def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """
    Stack per-sample tensors into batch tensors along a new leading dimension.

    Args:
        batch: List of samples from __getitem__

    Returns:
        Dict with one stacked tensor per known field; keys in a sample that
        are not listed below are ignored.
    """
    # Field order matches the order in which the batch dict is populated.
    field_names = (
        'seq_token_ids',
        'seq_attention_mask',
        'seq_residue_ids',
        'seq_branch_depths',
        'seq_linkage_types',
        'ms_token_ids',
        'ms_attention_mask',
        'ms_residue_ids',
        'struct_token_ids',
        'struct_attention_mask',
        'struct_residue_ids',
        'mono_indices',
        'mono_residue_ids',
        'has_ms',
        'has_3d',
        'has_residue_error',
        'dist_labels',
    )

    return {
        name: torch.stack([sample[name] for sample in batch])
        for name in field_names
    }
|
|
|
|
def create_multimodal_dataloaders(
    sequences_path: str,
    ms_tokens_path: str,
    structure_data_path: str,
    batch_size: int = 64,
    num_workers: int = 4,
    max_seq_length: int = 512,
    max_ms_length: int = 150,
    max_struct_length: int = 200,
    train_split: float = 0.8,
):
    """
    Build the full multimodal dataset and split it into train/validation loaders.

    Both modalities (MS and 3D) are always enabled. The split is random but
    reproducible: it uses a generator seeded with 42.

    Args:
        sequences_path: Path to sequences.pkl
        ms_tokens_path: Path to ms_tokens.pkl
        structure_data_path: Path to training_dataset.pkl (VQ-VAE tokens)
        batch_size: Batch size used by both loaders
        num_workers: Number of data loading workers per loader
        max_seq_length: Maximum sequence length
        max_ms_length: Maximum MS token length
        max_struct_length: Maximum structural token length
        train_split: Fraction of data for training (default 0.8 = 80/20 split)

    Returns:
        train_loader, val_loader
    """
    from torch.utils.data import DataLoader, random_split

    dataset = MultimodalGlycanDataset(
        sequences_path=sequences_path,
        ms_tokens_path=ms_tokens_path,
        structure_data_path=structure_data_path,
        max_seq_length=max_seq_length,
        max_ms_length=max_ms_length,
        max_struct_tokens=max_struct_length,
        include_ms=True,
        include_3d=True,
    )

    # The training share is rounded down; validation receives the remainder.
    n_total = len(dataset)
    train_size = int(train_split * n_total)
    val_size = n_total - train_size

    split_rng = torch.Generator().manual_seed(42)
    train_subset, val_subset = random_split(
        dataset, [train_size, val_size], generator=split_rng
    )

    # Settings shared by both loaders; only shuffling differs.
    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    train_loader = DataLoader(train_subset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_subset, shuffle=False, **loader_options)

    print(f"Created dataloaders: {train_size} train, {val_size} val")

    return train_loader, val_loader
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: load the dataset from the repository-relative default
    # paths, print one sample, then print the shapes of one collated batch.
    import sys
    from pathlib import Path

    from torch.utils.data import DataLoader

    package_root = Path(__file__).parent.parent
    base_path = package_root / "data"
    structure_pkl = package_root.parent / "structure/cluster_upload/files/multimodal_training_package/training_dataset.pkl"

    dataset = MultimodalGlycanDataset(
        sequences_path=str(base_path / "sequences.pkl"),
        ms_tokens_path=str(base_path / "ms_tokens.pkl"),
        structure_data_path=str(structure_pkl),
        max_seq_length=512,
        max_ms_length=150,
        max_struct_tokens=200,
        max_atoms=300,
        include_ms=True,
        include_3d=True,
    )

    rule = "=" * 80
    token_fields = ('seq_token_ids', 'ms_token_ids', 'struct_token_ids')

    print(rule)
    print("Testing Dataset")
    print(rule)

    # Single-sample inspection.
    first_sample = dataset[0]
    print(f"\nSample 0:")
    for field, tensor in first_sample.items():
        if not isinstance(tensor, torch.Tensor):
            print(f" {field}: {tensor}")
            continue
        print(f" {field}: shape={tensor.shape}, dtype={tensor.dtype}")
        if field in token_fields:
            # Token ID 0 is the padding value, so non-zero entries are real tokens.
            non_zero = (tensor != 0).sum().item()
            print(f" Non-padding tokens: {non_zero}")

    print(f"\n{'='*80}")
    print("Testing Batch")
    print(rule)

    # Batched inspection through the module's collate function.
    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=False,
        collate_fn=collate_fn,
    )

    first_batch = next(iter(loader))
    print(f"\nBatch shapes:")
    for field, tensor in first_batch.items():
        print(f" {field}: {tensor.shape}")

    print(f"\nBatch MS availability:")
    print(f" Samples with MS: {first_batch['has_ms'].sum().item()}/{len(first_batch['has_ms'])}")
|
|
|
|