supanthadey1's picture
Add BERTose and AFFINose training code release
1d6f391 verified
Raw
History Blame Contribute Delete
12.4 kB
"""
Simplified Dataset for Glycan Classification Fine-tuning
Works with clean benchmark CSV that has WURCS column.
Supports both atomic (v3) and BPE (v4) tokenization.
"""
import torch
from torch.utils.data import Dataset
import pandas as pd
from typing import Dict, List, Optional, Tuple
import logging
import json
from .tokenizer import WURCSTokenizer, create_tokenizer
logger = logging.getLogger(__name__)

# The BPE tokenizer module only ships with v4 of the tokenizers. Record whether
# it is importable so load_tokenizer() can fail with a clear message instead of
# a NameError when a BPE vocabulary is requested without it.
try:
    from .wurcs_bpe_tokenizer import WURCSBPETokenizer
except ImportError:
    HAS_BPE = False
else:
    HAS_BPE = True
def load_tokenizer(vocab_path: str):
    """
    Load the tokenizer that matches a vocabulary file.

    The vocabulary JSON is read once to decide which tokenizer to build: a
    top-level 'merges' key marks a BPE vocabulary, anything else is treated
    as an atomic vocabulary. The chosen tokenizer is then constructed from
    the same path.

    Args:
        vocab_path: Path to vocabulary JSON file

    Returns:
        WURCSBPETokenizer for BPE vocabularies, otherwise WURCSTokenizer.

    Raises:
        ImportError: a BPE vocabulary was detected but WURCSBPETokenizer
            could not be imported.
    """
    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab = json.load(f)

    # 'metadata' only feeds the log lines below; a vocabulary without it must
    # still load rather than crash on a KeyError while logging.
    metadata = vocab.get('metadata') or {}

    # Check if this is a BPE vocabulary (has 'merges' field)
    if 'merges' in vocab:
        if not HAS_BPE:
            raise ImportError("BPE vocabulary detected but WURCSBPETokenizer not available")
        logger.info(f"Loading BPE tokenizer (vocab_size={metadata.get('vocab_size', 'unknown')})")
        return WURCSBPETokenizer(vocab_path)

    # 167 is the fallback size the original code logged for atomic vocabularies
    # that carry no metadata.
    logger.info(f"Loading atomic tokenizer (vocab_size={metadata.get('vocab_size', 167)})")
    return WURCSTokenizer(vocab_path)
class GlycanClassificationDataset(Dataset):
    """
    Dataset for glycan classification tasks.

    Expects a CSV with columns:
        - target: IUPAC representation of glycan
        - wurcs: WURCS representation (required)
        - {task_name}: Label column (e.g., 'species', 'phylum')
        - split membership, either as per-split indicator columns
          ('train', 'validation' or 'valid', 'test'; bool or 0/1) or as a
          single string-valued 'split' column

    Rows without a WURCS string or without a label are dropped. Labels are
    mapped to contiguous integer ids in sorted label order.
    """

    def __init__(
        self,
        csv_path: str,
        task: str,
        split: str,
        vocab_path: str,
        max_length: int = 256,  # Reduced default for BPE (was 512)
        valid_classes: Optional[List[str]] = None,
    ):
        """
        Initialize dataset.

        Args:
            csv_path: Path to CSV file
            task: Task name (column name for labels)
            split: One of 'train', 'validation', 'test'
            vocab_path: Path to vocabulary.json
            max_length: Maximum sequence length passed to the tokenizer
            valid_classes: Optional list of classes to keep (strict filtering
                mode); samples with any other label are dropped.

        Raises:
            ValueError: the CSV has no column identifying the requested split.
        """
        self.task = task
        self.split = split
        self.max_length = max_length

        # Load tokenizer (auto-detects BPE vs atomic based on vocab file)
        self.tokenizer = load_tokenizer(vocab_path)

        df = pd.read_csv(csv_path)
        self.df = self._select_split(df, split)

        # Filter to only samples with WURCS
        initial_count = len(self.df)
        self.df = self.df[self.df['wurcs'].notna()].copy()
        removed = initial_count - len(self.df)
        if removed > 0:
            logger.info(f"Removed {removed} samples without WURCS")

        # Drop unlabeled rows, apply the optional class filter, build the mapping
        self.df, self.unique_labels = self._build_label_space(self.df, task, valid_classes)
        self.label_to_id = {label: i for i, label in enumerate(self.unique_labels)}
        self.id_to_label = {i: label for label, i in self.label_to_id.items()}

        # Log info
        logger.info(f"Loaded {len(self.df)} samples for {split} split")
        logger.info(f" Task: {task}")
        logger.info(f" Classes: {len(self.unique_labels)}")

    @staticmethod
    def _select_split(df: pd.DataFrame, split: str) -> pd.DataFrame:
        """
        Return a copy of the rows of ``df`` belonging to ``split``.

        Handles the column naming conventions seen in the benchmark CSVs:
        - Classification CSV: 'train', 'validation', 'test' indicator columns
        - Immunogenicity/Link CSV: 'train', 'valid', 'test' indicator columns
        - A single 'split' column with string values

        Raises:
            ValueError: none of the candidate columns exists.
        """
        # Map 'validation' to 'valid' if only the short name exists
        split_col = split
        if split == 'validation' and 'validation' not in df.columns and 'valid' in df.columns:
            split_col = 'valid'

        if split_col in df.columns:
            # Indicator column: boolean masks select directly, otherwise 1 marks membership
            if df[split_col].dtype == 'bool':
                return df[df[split_col]].copy()
            return df[df[split_col] == 1].copy()
        if 'split' in df.columns:
            return df[df['split'] == split].copy()
        raise ValueError(f"Cannot find split column '{split}' or '{split_col}' or 'split' in CSV")

    @staticmethod
    def _build_label_space(
        df: pd.DataFrame,
        task: str,
        valid_classes: Optional[List[str]],
    ) -> Tuple[pd.DataFrame, List]:
        """
        Drop unlabeled rows, optionally restrict to ``valid_classes``, and
        return ``(filtered_df, sorted_labels)``.

        In strict mode the label list keeps only the valid classes actually
        present in this split.
        """
        df = df[df[task].notna()].copy()
        if valid_classes is None:
            return df, sorted(df[task].unique())

        # Apply valid_classes filter (for strict filtering mode)
        before_filter = len(df)
        df = df[df[task].isin(valid_classes)].copy()
        filtered_out = before_filter - len(df)
        if filtered_out > 0:
            logger.info(f" Filtered out {filtered_out} samples (classes not in all splits)")

        # Set lookup instead of a linear scan of the label column per class
        present = set(df[task].unique())
        return df, sorted(c for c in valid_classes if c in present)

    @staticmethod
    def _structure_features(tokenized: Dict) -> Tuple[int, bool]:
        """
        Return ``(num_residues, is_branched)`` for one tokenizer output.

        Uses the tokenizer's own values when present (atomic tokenizer);
        otherwise derives them, since the BPE tokenizer may not provide them:
        residues as max(residue_ids) + 1 (0 for an empty sequence), branching
        as the presence of a '[BRANCH_OPEN]' token.
        """
        # Computed lazily: evaluating max() as a .get() default ran even when
        # 'num_residues' existed and crashed on empty residue_ids.
        if 'num_residues' in tokenized:
            num_residues = tokenized['num_residues']
        else:
            num_residues = max(tokenized.get('residue_ids', [0]), default=-1) + 1

        if 'is_branched' in tokenized:
            is_branched = tokenized['is_branched']
        else:
            is_branched = '[BRANCH_OPEN]' in tokenized.get('tokens', [])
        return num_residues, is_branched

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Tokenize the WURCS string of row ``idx`` and return model inputs.

        Returns a dict of long tensors ('token_ids', 'attention_mask',
        'residue_ids', 'branch_depths', 'linkage_types', 'label') plus the
        plain-Python 'num_residues' and 'is_branched' values.
        """
        row = self.df.iloc[idx]

        # Tokenize WURCS
        tokenized = self.tokenizer.tokenize(row['wurcs'], max_length=self.max_length)

        # Get label
        label = self.label_to_id[row[self.task]]

        num_residues, is_branched = self._structure_features(tokenized)

        return {
            'token_ids': torch.tensor(tokenized['token_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(tokenized['attention_mask'], dtype=torch.long),
            'residue_ids': torch.tensor(tokenized['residue_ids'], dtype=torch.long),
            'branch_depths': torch.tensor(tokenized['branch_depths'], dtype=torch.long),
            'linkage_types': torch.tensor(tokenized['linkage_types'], dtype=torch.long),
            'num_residues': num_residues,
            'is_branched': is_branched,
            'label': torch.tensor(label, dtype=torch.long),
        }

    def get_class_weights(self) -> torch.Tensor:
        """
        Compute class weights for imbalanced data.

        Weight for class c is ``total / (n_classes * count_c)`` (inverse
        frequency, the same form as scikit-learn's "balanced" weights),
        ordered like ``self.unique_labels``.

        Returns:
            Float tensor of class weights
        """
        class_counts = self.df[self.task].value_counts()
        total = len(self.df)
        n_classes = len(self.unique_labels)
        # A label missing from this split falls back to a count of 1 to avoid division by zero
        weights = [total / (n_classes * class_counts.get(label, 1)) for label in self.unique_labels]
        return torch.tensor(weights, dtype=torch.float)
def compute_valid_classes(
    csv_path: str,
    task: str,
    min_samples: int = 1,
) -> List[str]:
    """
    Return the classes that occur at least ``min_samples`` times in every split.

    Used for the 'strict' filtering mode (GlycanML approach): a class survives
    only if it meets the threshold in train, validation and test alike. Rows
    without a WURCS string or without a label are ignored.

    Split membership comes from indicator columns compared against 1
    ('train'; 'validation', else 'valid'; 'test'), falling back to a
    string-valued 'split' column when the indicator column is absent.

    Args:
        csv_path: Path to CSV file
        task: Task column name
        min_samples: Minimum samples per class per split (default 1)

    Returns:
        Sorted list of class names that qualify in all three splits
    """
    frame = pd.read_csv(csv_path)
    # Only rows usable by the dataset (WURCS present, label present) count
    frame = frame[frame['wurcs'].notna() & frame[task].notna()]

    def rows_of(indicator_columns, split_name):
        # First indicator column present wins; otherwise use the 'split' column
        for column in indicator_columns:
            if column in frame.columns:
                return frame[frame[column] == 1]
        return frame[frame['split'] == split_name]

    split_counts = {
        'Train': rows_of(('train',), 'train')[task].value_counts(),
        'Val': rows_of(('validation', 'valid'), 'validation')[task].value_counts(),
        'Test': rows_of(('test',), 'test')[task].value_counts(),
    }
    qualifying = {
        name: set(counts[counts >= min_samples].index)
        for name, counts in split_counts.items()
    }

    # A class must reach the threshold in every split
    valid_classes = sorted(set.intersection(*qualifying.values()))
    seen_anywhere = set().union(*(counts.index for counts in split_counts.values()))

    logger.info(f"Computing valid classes for {task} (min_samples={min_samples}):")
    for name, classes in qualifying.items():
        logger.info(f" {name} classes (>={min_samples}): {len(classes)}")
    logger.info(f" Valid (in all splits): {len(valid_classes)}")
    logger.info(f" Excluded: {len(seen_anywhere) - len(valid_classes)}")
    return valid_classes
def filter_to_valid_classes(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    test_df: pd.DataFrame,
    task: str,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, List[str]]:
    """
    Keep only rows whose label appears (non-null) in all three splits.

    Args:
        train_df: Training data
        val_df: Validation data
        test_df: Test data
        task: Task column name

    Returns:
        (train_df, val_df, test_df, valid_classes) where each frame is a
        filtered copy and valid_classes is the sorted shared label set
    """
    frames = (train_df, val_df, test_df)
    present = [set(frame[task].dropna().unique()) for frame in frames]
    valid_classes = sorted(set.intersection(*present))

    logger.info(f"Filtering classes for {task}:")
    for name, classes in zip(("Train", "Val", "Test"), present):
        logger.info(f" {name} classes: {len(classes)}")
    logger.info(f" Valid (in all): {len(valid_classes)}")

    kept = [frame[frame[task].isin(valid_classes)].copy() for frame in frames]
    return (*kept, valid_classes)
def create_dataloaders(
    csv_path: str,
    task: str,
    vocab_path: str,
    batch_size: int = 64,
    max_length: int = 512,
    num_workers: int = 4,
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """
    Create train, validation, and test dataloaders.

    Each GlycanClassificationDataset derives its label -> id mapping from the
    labels present in its own split, so independently built splits disagree
    on ids whenever their class sets differ. A split can also produce ids
    beyond the training label range. To keep integer labels comparable, the
    validation and test datasets are therefore restricted to the training
    classes and given the training split's mapping. When every split contains
    the same classes, this is identical to building them independently.

    Args:
        csv_path: Path to CSV file
        task: Task name
        vocab_path: Path to vocabulary.json
        batch_size: Batch size
        max_length: Maximum sequence length
        num_workers: Number of data loading workers

    Returns:
        (train_loader, val_loader, test_loader); only the training loader shuffles
    """
    train_dataset = GlycanClassificationDataset(
        csv_path, task, 'train', vocab_path, max_length
    )
    train_labels = list(train_dataset.unique_labels)

    def _eval_dataset(split):
        # Drop samples whose class never occurs in training (they have no
        # training id), then force the training id assignment so a split
        # missing some class does not shift the ids of the others.
        dataset = GlycanClassificationDataset(
            csv_path, task, split, vocab_path, max_length,
            valid_classes=train_labels,
        )
        dataset.unique_labels = list(train_labels)
        dataset.label_to_id = dict(train_dataset.label_to_id)
        dataset.id_to_label = dict(train_dataset.id_to_label)
        return dataset

    val_dataset = _eval_dataset('validation')
    test_dataset = _eval_dataset('test')

    def _loader(dataset, shuffle):
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
        )

    return (
        _loader(train_dataset, shuffle=True),
        _loader(val_dataset, shuffle=False),
        _loader(test_dataset, shuffle=False),
    )