| """ |
| Simplified Dataset for Glycan Classification Fine-tuning |
| |
| Works with clean benchmark CSV that has WURCS column. |
| Supports both atomic (v3) and BPE (v4) tokenization. |
| """ |
|
|
| import torch |
| from torch.utils.data import Dataset |
| import pandas as pd |
| from typing import Dict, List, Optional, Tuple |
| import logging |
| import json |
|
|
| from .tokenizer import WURCSTokenizer, create_tokenizer |
|
|
| |
| try: |
| from .wurcs_bpe_tokenizer import WURCSBPETokenizer |
| HAS_BPE = True |
| except ImportError: |
| HAS_BPE = False |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def load_tokenizer(vocab_path: str):
    """
    Load tokenizer based on vocabulary file type.

    A vocabulary JSON that contains a ``'merges'`` key is treated as a BPE
    vocabulary and loaded with ``WURCSBPETokenizer``. Anything else is loaded
    as an atomic vocabulary with ``WURCSTokenizer``. In both cases the
    tokenizer is constructed from ``vocab_path`` itself; the JSON is only read
    here to decide which class to use and to log the vocabulary size.

    Args:
        vocab_path: Path to vocabulary JSON file

    Returns:
        Either WURCSTokenizer (atomic) or WURCSBPETokenizer (BPE)

    Raises:
        ImportError: If a BPE vocabulary is detected but ``WURCSBPETokenizer``
            could not be imported.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab = json.load(f)

    # 'metadata' is optional and only feeds the log lines below, so a
    # vocabulary without it must not break tokenizer loading.
    metadata = vocab.get('metadata') or {}

    if 'merges' in vocab:
        if not HAS_BPE:
            raise ImportError("BPE vocabulary detected but WURCSBPETokenizer not available")
        logger.info(f"Loading BPE tokenizer (vocab_size={metadata.get('vocab_size', 'unknown')})")
        return WURCSBPETokenizer(vocab_path)

    # 167 is the fallback size logged for atomic vocabularies without
    # metadata (presumably the size of the v3 atomic vocab -- confirm).
    logger.info(f"Loading atomic tokenizer (vocab_size={metadata.get('vocab_size', 167)})")
    return WURCSTokenizer(vocab_path)
|
|
|
|
class GlycanClassificationDataset(Dataset):
    """
    Dataset for glycan classification tasks.

    Expects a CSV with columns:
        - wurcs: WURCS representation (required; rows where it is missing are dropped)
        - {task_name}: Label column (e.g., 'species', 'phylum'); rows where it
          is missing are dropped
        - split membership, given either as one indicator column per split
          ('train', 'validation' -- or 'valid' -- and 'test', holding booleans
          or 0/1), or as a single 'split' column holding 'train',
          'validation' or 'test'

    Other columns (e.g. 'target', the IUPAC representation) are not read by
    this class. Samples are tokenized lazily in ``__getitem__``.
    """

    def __init__(
        self,
        csv_path: str,
        task: str,
        split: str,
        vocab_path: str,
        max_length: int = 256,
        valid_classes: Optional[List[str]] = None,
    ):
        """
        Initialize dataset.

        Args:
            csv_path: Path to CSV file
            task: Task name (column name for labels)
            split: One of 'train', 'validation', 'test'
            vocab_path: Path to vocabulary.json
            max_length: Maximum sequence length passed to the tokenizer
            valid_classes: Optional list of classes to keep. Rows with other
                labels are dropped, and the label space is restricted to the
                listed classes that actually occur in this split.

        Raises:
            ValueError: If no usable split column is found in the CSV.
            KeyError: If the 'wurcs' column or the ``task`` column is missing.
        """
        self.task = task
        self.split = split
        self.max_length = max_length

        self.tokenizer = load_tokenizer(vocab_path)

        df = pd.read_csv(csv_path)
        self.df = self._select_split(df, split)

        initial_count = len(self.df)
        self.df = self.df[self.df['wurcs'].notna()].copy()
        removed = initial_count - len(self.df)
        if removed > 0:
            logger.info(f"Removed {removed} samples without WURCS")

        self.df = self.df[self.df[task].notna()].copy()

        if valid_classes is not None:
            before_filter = len(self.df)
            self.df = self.df[self.df[task].isin(valid_classes)].copy()
            filtered_out = before_filter - len(self.df)
            if filtered_out > 0:
                logger.info(f" Filtered out {filtered_out} samples (classes not in all splits)")
            # A set gives O(1) membership tests; the set comprehension also
            # drops duplicate entries in valid_classes, which would otherwise
            # create an id that no sample can ever carry.
            present = set(self.df[task])
            self.unique_labels = sorted({c for c in valid_classes if c in present})
        else:
            self.unique_labels = sorted(self.df[task].unique())

        self.label_to_id = {label: i for i, label in enumerate(self.unique_labels)}
        self.id_to_label = {i: label for label, i in self.label_to_id.items()}

        logger.info(f"Loaded {len(self.df)} samples for {split} split")
        logger.info(f" Task: {task}")
        logger.info(f" Classes: {len(self.unique_labels)}")

    @staticmethod
    def _select_split(df: pd.DataFrame, split: str) -> pd.DataFrame:
        """
        Return a copy of the rows of ``df`` that belong to ``split``.

        The lookup order is:
        1. A column named after the split. For 'validation', a 'valid' column
           is used if there is no 'validation' column. Boolean columns are
           used as a mask directly; other columns are compared with 1.
        2. Otherwise, a 'split' column compared against the split name.

        Raises:
            ValueError: If neither kind of split column exists.
        """
        split_col = split
        if split == 'validation' and 'validation' not in df.columns and 'valid' in df.columns:
            split_col = 'valid'

        if split_col in df.columns:
            flags = df[split_col]
            mask = flags if flags.dtype == 'bool' else flags == 1
            return df[mask].copy()
        if 'split' in df.columns:
            return df[df['split'] == split].copy()
        raise ValueError(f"Cannot find split column '{split}' or '{split_col}' or 'split' in CSV")

    @staticmethod
    def _structure_summary(tokenized: Dict) -> Tuple[object, object]:
        """
        Return ``(num_residues, is_branched)`` for one tokenizer output dict.

        Values the tokenizer provides are used as-is. The fallbacks are
        computed only when a key is missing, so a tokenizer that reports
        ``num_residues`` alongside an empty ``residue_ids`` does not crash:
        - num_residues: ``max(residue_ids) + 1``. This is 1 when
          ``residue_ids`` is absent and 0 when it is empty.
        - is_branched: whether the '[BRANCH_OPEN]' token appears in 'tokens'.
        """
        if 'num_residues' in tokenized:
            num_residues = tokenized['num_residues']
        else:
            residue_ids = tokenized.get('residue_ids', [0])
            # len() instead of truthiness so array-like ids also work.
            num_residues = max(residue_ids) + 1 if len(residue_ids) > 0 else 0

        if 'is_branched' in tokenized:
            is_branched = tokenized['is_branched']
        else:
            is_branched = '[BRANCH_OPEN]' in tokenized.get('tokens', [])

        return num_residues, is_branched

    def __len__(self) -> int:
        """Number of samples in this split after filtering."""
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Tokenize sample ``idx`` and return model inputs plus its class id.

        The per-token fields and the label are ``torch.long`` tensors.
        ``num_residues`` and ``is_branched`` are passed through without
        converting them to tensors.
        """
        row = self.df.iloc[idx]

        tokenized = self.tokenizer.tokenize(row['wurcs'], max_length=self.max_length)
        label = self.label_to_id[row[self.task]]
        num_residues, is_branched = self._structure_summary(tokenized)

        return {
            'token_ids': torch.tensor(tokenized['token_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(tokenized['attention_mask'], dtype=torch.long),
            'residue_ids': torch.tensor(tokenized['residue_ids'], dtype=torch.long),
            'branch_depths': torch.tensor(tokenized['branch_depths'], dtype=torch.long),
            'linkage_types': torch.tensor(tokenized['linkage_types'], dtype=torch.long),
            'num_residues': num_residues,
            'is_branched': is_branched,
            'label': torch.tensor(label, dtype=torch.long),
        }

    def get_class_weights(self) -> torch.Tensor:
        """
        Compute class weights for imbalanced data.

        Each weight is ``total / (n_classes * count)``, i.e. inverse class
        frequency. Weights follow the order of ``self.unique_labels``, and a
        class with no samples is treated as having a count of 1.

        Returns:
            Float tensor with one weight per class.
        """
        class_counts = self.df[self.task].value_counts()
        total = len(self.df)
        n_classes = len(self.unique_labels)

        weights = [
            total / (n_classes * class_counts.get(label, 1))
            for label in self.unique_labels
        ]
        return torch.tensor(weights, dtype=torch.float)
|
|
|
|
def _rows_for_split(df: pd.DataFrame, split: str) -> pd.DataFrame:
    """
    Return the rows of ``df`` assigned to ``split``.

    A column named after the split is used as an indicator. For 'validation',
    a 'valid' column is accepted if there is no 'validation' column. Boolean
    indicator columns are used directly as a mask; other indicator columns
    are compared with 1. Without an indicator column, a 'split' column is
    compared against the split name.

    Raises:
        ValueError: If neither kind of split column exists.
    """
    split_col = split
    if split == 'validation' and 'validation' not in df.columns and 'valid' in df.columns:
        split_col = 'valid'

    if split_col in df.columns:
        flags = df[split_col]
        return df[flags if flags.dtype == 'bool' else flags == 1]
    if 'split' in df.columns:
        return df[df['split'] == split]
    raise ValueError(f"Cannot find split column '{split}' or '{split_col}' or 'split' in CSV")


def compute_valid_classes(
    csv_path: str,
    task: str,
    min_samples: int = 1,
) -> List[str]:
    """
    Compute classes that are present in all splits (train, val, test)
    with at least min_samples in each split.

    This is used for 'strict' filtering mode (GlycanML approach). Rows with a
    missing 'wurcs' value or a missing label are ignored before counting.

    Args:
        csv_path: Path to CSV file
        task: Task column name
        min_samples: Minimum samples per class per split (default 1)

    Returns:
        Sorted list of valid class names

    Raises:
        ValueError: If the CSV has no usable split column.
    """
    df = pd.read_csv(csv_path)
    df = df[df['wurcs'].notna() & df[task].notna()]

    train_counts = _rows_for_split(df, 'train')[task].value_counts()
    val_counts = _rows_for_split(df, 'validation')[task].value_counts()
    test_counts = _rows_for_split(df, 'test')[task].value_counts()

    train_classes = set(train_counts[train_counts >= min_samples].index)
    val_classes = set(val_counts[val_counts >= min_samples].index)
    test_classes = set(test_counts[test_counts >= min_samples].index)

    valid_classes = sorted(train_classes & val_classes & test_classes)

    # Every class seen in any split; used only to report how many were excluded.
    all_classes = set(train_counts.index) | set(val_counts.index) | set(test_counts.index)

    logger.info(f"Computing valid classes for {task} (min_samples={min_samples}):")
    logger.info(f" Train classes (>={min_samples}): {len(train_classes)}")
    logger.info(f" Val classes (>={min_samples}): {len(val_classes)}")
    logger.info(f" Test classes (>={min_samples}): {len(test_classes)}")
    logger.info(f" Valid (in all splits): {len(valid_classes)}")
    logger.info(f" Excluded: {len(all_classes) - len(valid_classes)}")

    return valid_classes
|
|
|
|
def filter_to_valid_classes(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    test_df: pd.DataFrame,
    task: str,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, List[str]]:
    """
    Keep only the labels that occur in every one of the three splits.

    Missing labels (NaN/None) never count as a class. Rows whose label is
    missing are therefore also absent from the outputs, because they do not
    match the shared class list.

    Args:
        train_df: Training data
        val_df: Validation data
        test_df: Test data
        task: Task column name

    Returns:
        ``(train_df, val_df, test_df, valid_classes)``: filtered copies of the
        three frames (original index preserved) and the sorted shared labels.
    """
    frames = (train_df, val_df, test_df)
    seen_train, seen_val, seen_test = [
        set(frame[task].dropna().unique()) for frame in frames
    ]
    shared = sorted(seen_train.intersection(seen_val, seen_test))

    logger.info(f"Filtering classes for {task}:")
    logger.info(f" Train classes: {len(seen_train)}")
    logger.info(f" Val classes: {len(seen_val)}")
    logger.info(f" Test classes: {len(seen_test)}")
    logger.info(f" Valid (in all): {len(shared)}")

    kept = [frame.loc[frame[task].isin(shared)].copy() for frame in frames]
    return (*kept, shared)
|
|
|
|
def _align_label_space(dataset, reference) -> None:
    """
    Make ``dataset`` use the label <-> id mapping of ``reference``, in place.

    Rows whose label has no id in ``reference`` are dropped, because a model
    trained on ``reference``'s label space cannot be scored on them.
    """
    known = reference.label_to_id
    keep = dataset.df[dataset.task].isin(list(known))
    dropped = int((~keep).sum())
    if dropped:
        logger.warning(
            f"Dropped {dropped} {dataset.split} samples whose labels do not occur in the training split"
        )
    dataset.df = dataset.df[keep].copy()
    dataset.unique_labels = list(reference.unique_labels)
    dataset.label_to_id = dict(reference.label_to_id)
    dataset.id_to_label = dict(reference.id_to_label)


def create_dataloaders(
    csv_path: str,
    task: str,
    vocab_path: str,
    batch_size: int = 64,
    max_length: int = 512,
    num_workers: int = 4,
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """
    Create train, validation, and test dataloaders.

    Validation and test use the training split's label ids. Their samples
    with labels never seen in training are dropped.

    Args:
        csv_path: Path to CSV file
        task: Task name
        vocab_path: Path to vocabulary.json
        batch_size: Batch size
        max_length: Maximum sequence length
        num_workers: Number of data loading workers

    Returns:
        (train_loader, val_loader, test_loader). Only the train loader shuffles.
    """
    train_dataset = GlycanClassificationDataset(
        csv_path, task, 'train', vocab_path, max_length
    )
    val_dataset = GlycanClassificationDataset(
        csv_path, task, 'validation', vocab_path, max_length
    )
    test_dataset = GlycanClassificationDataset(
        csv_path, task, 'test', vocab_path, max_length
    )

    # Each dataset numbers its labels from its own split. If val/test lack a
    # training class (or add one), the ids would silently disagree with the
    # ids the model was trained on, so both are re-keyed to train's mapping.
    for eval_dataset in (val_dataset, test_dataset):
        _align_label_space(eval_dataset, train_dataset)

    # Pinned host memory only speeds up copies to a CUDA device.
    pin_memory = torch.cuda.is_available()

    def _make_loader(dataset, shuffle: bool) -> torch.utils.data.DataLoader:
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

    return (
        _make_loader(train_dataset, shuffle=True),
        _make_loader(val_dataset, shuffle=False),
        _make_loader(test_dataset, shuffle=False),
    )
|
|
|
|