"""
Simplified Dataset for Glycan Classification Fine-tuning.

Works with a clean benchmark CSV that has a WURCS column.
Supports both atomic (v3) and BPE (v4) tokenization.
"""

import torch
from torch.utils.data import Dataset
import pandas as pd
from typing import Dict, List, Optional, Tuple
import logging
import json

from .tokenizer import WURCSTokenizer, create_tokenizer

# Try to import BPE tokenizer (available in v4)
try:
    from .wurcs_bpe_tokenizer import WURCSBPETokenizer
    HAS_BPE = True
except ImportError:
    HAS_BPE = False

logger = logging.getLogger(__name__)


def load_tokenizer(vocab_path: str):
    """
    Load tokenizer based on vocabulary file type.

    A vocabulary JSON with a top-level 'merges' key is treated as a BPE
    vocabulary; anything else is loaded with the atomic tokenizer.

    Args:
        vocab_path: Path to vocabulary JSON file

    Returns:
        Either WURCSTokenizer (atomic) or WURCSBPETokenizer (BPE)

    Raises:
        ImportError: If a BPE vocabulary is detected but WURCSBPETokenizer
            could not be imported.
    """
    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab = json.load(f)

    # The size is only used for logging, so a missing 'metadata' block must
    # not abort loading (the BPE branch previously raised KeyError here).
    metadata = vocab.get('metadata', {})

    # Check if this is a BPE vocabulary (has 'merges' field)
    if 'merges' in vocab:
        if not HAS_BPE:
            raise ImportError("BPE vocabulary detected but WURCSBPETokenizer not available")
        logger.info(f"Loading BPE tokenizer (vocab_size={metadata.get('vocab_size', 'unknown')})")
        return WURCSBPETokenizer(vocab_path)

    logger.info(f"Loading atomic tokenizer (vocab_size={metadata.get('vocab_size', 167)})")
    return WURCSTokenizer(vocab_path)


def _select_split(df: pd.DataFrame, split: str) -> pd.DataFrame:
    """
    Return a copy of the rows of ``df`` that belong to ``split``.

    Handles the column naming conventions seen in the benchmark CSVs:
        - Classification CSV: 'train', 'validation', 'test' columns (binary 0/1)
        - Immunogenicity/Link CSV: 'train', 'valid', 'test' columns (binary 0/1)
        - A single 'split' column with string values

    Raises:
        ValueError: If no usable split column is found.
    """
    # Map 'validation' to 'valid' if only the short name exists
    split_col = split
    if split == 'validation' and 'validation' not in df.columns and 'valid' in df.columns:
        split_col = 'valid'

    if split_col in df.columns:
        # Binary membership column for this split
        if df[split_col].dtype == 'bool':
            return df[df[split_col]].copy()
        return df[df[split_col] == 1].copy()  # int64 or similar

    if 'split' in df.columns:
        # Single 'split' column with string values
        return df[df['split'] == split].copy()

    raise ValueError(f"Cannot find split column '{split}' or '{split_col}' or 'split' in CSV")


class GlycanClassificationDataset(Dataset):
    """
    Dataset for glycan classification tasks.

    Expects a CSV with columns:
        - target: IUPAC representation of glycan
        - wurcs: WURCS representation (required; rows without it are dropped)
        - {task_name}: Label column (e.g., 'species', 'phylum')
        - split membership: either per-split binary columns
          ('train', 'validation' or 'valid', 'test') or a single string
          'split' column with values 'train' / 'validation' / 'test'
    """

    def __init__(
        self,
        csv_path: str,
        task: str,
        split: str,
        vocab_path: str,
        max_length: int = 256,  # Reduced default for BPE (was 512)
        valid_classes: Optional[List[str]] = None,
        label_vocab: Optional[List[str]] = None,
    ):
        """
        Initialize dataset.

        Args:
            csv_path: Path to CSV file
            task: Task name (column name for labels)
            split: One of 'train', 'validation', 'test'
            vocab_path: Path to vocabulary.json
            max_length: Maximum sequence length passed to the tokenizer
            valid_classes: Optional list of valid classes to filter to
                (strict filtering mode)
            label_vocab: Optional ordered list of class labels that fixes the
                label -> id mapping (e.g. the training split's
                ``unique_labels``). Use it so every split shares one mapping;
                rows whose label is not in it are dropped.

        Raises:
            ValueError: If the CSV has no usable split column.
        """
        self.task = task
        self.split = split
        self.max_length = max_length

        # Load tokenizer (auto-detects BPE vs atomic based on vocab file)
        self.tokenizer = load_tokenizer(vocab_path)

        # Load data and keep only this split's rows
        df = pd.read_csv(csv_path)
        self.df = _select_split(df, split)

        # Filter to only samples with WURCS
        initial_count = len(self.df)
        self.df = self.df[self.df['wurcs'].notna()].copy()
        removed = initial_count - len(self.df)
        if removed > 0:
            logger.info(f"Removed {removed} samples without WURCS")

        # Drop samples without a label
        self.df = self.df[self.df[task].notna()].copy()

        # Apply valid_classes filter if provided (for strict filtering mode)
        if valid_classes is not None:
            before_filter = len(self.df)
            self.df = self.df[self.df[task].isin(valid_classes)].copy()
            filtered_out = before_filter - len(self.df)
            if filtered_out > 0:
                logger.info(f"  Filtered out {filtered_out} samples (classes not in all splits)")

        if label_vocab is not None:
            # Caller-fixed mapping; dict.fromkeys de-duplicates while keeping order
            # so ids line up exactly with the split the vocabulary came from.
            self.unique_labels = list(dict.fromkeys(label_vocab))
            before_filter = len(self.df)
            self.df = self.df[self.df[task].isin(self.unique_labels)].copy()
            dropped = before_filter - len(self.df)
            if dropped > 0:
                logger.warning(
                    f"  Dropped {dropped} {split} samples whose label is not in the provided label vocabulary"
                )
        elif valid_classes is not None:
            # Set membership (not `in ndarray`) and a set comprehension so duplicate
            # entries in valid_classes cannot produce duplicate label ids.
            present = set(self.df[task].unique())
            self.unique_labels = sorted({c for c in valid_classes if c in present})
        else:
            self.unique_labels = sorted(self.df[task].unique())

        self.label_to_id = {label: i for i, label in enumerate(self.unique_labels)}
        self.id_to_label = {i: label for label, i in self.label_to_id.items()}

        # Log info
        logger.info(f"Loaded {len(self.df)} samples for {split} split")
        logger.info(f"  Task: {task}")
        logger.info(f"  Classes: {len(self.unique_labels)}")

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize the idx-th sample and return model inputs plus its label id."""
        row = self.df.iloc[idx]

        # Tokenize WURCS
        tokenized = self.tokenizer.tokenize(row['wurcs'], max_length=self.max_length)

        # Get label
        label = self.label_to_id[row[self.task]]

        # Handle fields that may not exist in BPE tokenizer output. Fallbacks are
        # computed only when the field is missing; max(default=...) avoids a
        # ValueError on an empty residue_ids list.
        if 'num_residues' in tokenized:
            num_residues = tokenized['num_residues']
        else:
            num_residues = max(tokenized.get('residue_ids', [0]), default=0) + 1
        if 'is_branched' in tokenized:
            is_branched = tokenized['is_branched']
        else:
            is_branched = '[BRANCH_OPEN]' in tokenized.get('tokens', [])

        return {
            'token_ids': torch.tensor(tokenized['token_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(tokenized['attention_mask'], dtype=torch.long),
            'residue_ids': torch.tensor(tokenized['residue_ids'], dtype=torch.long),
            'branch_depths': torch.tensor(tokenized['branch_depths'], dtype=torch.long),
            'linkage_types': torch.tensor(tokenized['linkage_types'], dtype=torch.long),
            'num_residues': num_residues,
            'is_branched': is_branched,
            'label': torch.tensor(label, dtype=torch.long),
        }

    def get_class_weights(self) -> torch.Tensor:
        """
        Compute class weights for imbalanced data.

        weight[c] = total / (n_classes * count[c]); a class with no samples in
        this split is treated as having a count of 1.

        Returns:
            Tensor of class weights (inverse frequency), ordered like unique_labels
        """
        class_counts = self.df[self.task].value_counts()
        total = len(self.df)
        n_classes = len(self.unique_labels)

        weights = [
            total / (n_classes * class_counts.get(label, 1))
            for label in self.unique_labels
        ]
        return torch.tensor(weights, dtype=torch.float)


def compute_valid_classes(
    csv_path: str,
    task: str,
    min_samples: int = 1,
) -> List[str]:
    """
    Compute classes that are present in all splits (train, val, test)
    with at least min_samples in each split.

    This is used for 'strict' filtering mode (GlycanML approach).
    Split membership is resolved exactly as GlycanClassificationDataset does it.

    Args:
        csv_path: Path to CSV file
        task: Task column name
        min_samples: Minimum samples per class per split (default 1)

    Returns:
        List of valid class names (sorted)

    Raises:
        ValueError: If the CSV has no usable split column.
    """
    df = pd.read_csv(csv_path)

    # Only consider labelled samples with WURCS
    df = df[df['wurcs'].notna()]
    df = df[df[task].notna()]

    # Get class counts per split
    train_counts = _select_split(df, 'train')[task].value_counts()
    val_counts = _select_split(df, 'validation')[task].value_counts()
    test_counts = _select_split(df, 'test')[task].value_counts()

    # Get classes with >= min_samples in each split
    train_classes = set(train_counts[train_counts >= min_samples].index)
    val_classes = set(val_counts[val_counts >= min_samples].index)
    test_classes = set(test_counts[test_counts >= min_samples].index)

    # Classes must meet min_samples threshold in all splits
    valid_classes = sorted(train_classes & val_classes & test_classes)
    all_classes = set(train_counts.index) | set(val_counts.index) | set(test_counts.index)

    logger.info(f"Computing valid classes for {task} (min_samples={min_samples}):")
    logger.info(f"  Train classes (>={min_samples}): {len(train_classes)}")
    logger.info(f"  Val classes (>={min_samples}): {len(val_classes)}")
    logger.info(f"  Test classes (>={min_samples}): {len(test_classes)}")
    logger.info(f"  Valid (in all splits): {len(valid_classes)}")
    logger.info(f"  Excluded: {len(all_classes) - len(valid_classes)}")

    return valid_classes


def filter_to_valid_classes(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    test_df: pd.DataFrame,
    task: str,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, List[str]]:
    """
    Filter datasets to only include classes present in all splits.

    Args:
        train_df: Training data
        val_df: Validation data
        test_df: Test data
        task: Task column name

    Returns:
        Filtered (train_df, val_df, test_df, valid_classes); the frames are copies
    """
    train_classes = set(train_df[task].dropna().unique())
    val_classes = set(val_df[task].dropna().unique())
    test_classes = set(test_df[task].dropna().unique())

    # Classes must be in all splits
    valid_classes = sorted(train_classes & val_classes & test_classes)

    logger.info(f"Filtering classes for {task}:")
    logger.info(f"  Train classes: {len(train_classes)}")
    logger.info(f"  Val classes: {len(val_classes)}")
    logger.info(f"  Test classes: {len(test_classes)}")
    logger.info(f"  Valid (in all): {len(valid_classes)}")

    train_df = train_df[train_df[task].isin(valid_classes)].copy()
    val_df = val_df[val_df[task].isin(valid_classes)].copy()
    test_df = test_df[test_df[task].isin(valid_classes)].copy()

    return train_df, val_df, test_df, valid_classes


def create_dataloaders(
    csv_path: str,
    task: str,
    vocab_path: str,
    batch_size: int = 64,
    max_length: int = 512,
    num_workers: int = 4,
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """
    Create train, validation, and test dataloaders.

    The validation and test datasets reuse the training split's label
    mapping, so a given label id denotes the same class in every split;
    val/test samples whose class never appears in training are dropped.

    Args:
        csv_path: Path to CSV file
        task: Task name
        vocab_path: Path to vocabulary.json
        batch_size: Batch size
        max_length: Maximum sequence length (note: differs from the dataset's
            own default of 256)
        num_workers: Number of data loading workers

    Returns:
        (train_loader, val_loader, test_loader)
    """
    # Create datasets. Building each split's mapping independently would shift
    # ids whenever the splits' class sets differ, so val/test share train's.
    train_dataset = GlycanClassificationDataset(
        csv_path, task, 'train', vocab_path, max_length
    )
    shared_labels = train_dataset.unique_labels
    val_dataset = GlycanClassificationDataset(
        csv_path, task, 'validation', vocab_path, max_length, label_vocab=shared_labels
    )
    test_dataset = GlycanClassificationDataset(
        csv_path, task, 'test', vocab_path, max_length, label_vocab=shared_labels
    )

    # Pinned memory only helps host->GPU copies; without CUDA it just warns.
    loader_kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'pin_memory': torch.cuda.is_available(),
    }

    # Create dataloaders (only training data is shuffled)
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader