A2D2 / a2d2_mol /mol_dataset.py
Sophia
initial commit
8019be0
Raw
History Blame Contribute Delete
14.9 kB
#!/usr/bin/env python
"""
Adapter to use HuggingFace datasets with the any-length discrete diffusion model.
This module converts HuggingFace datasets (like datamol-io/safe-drugs) into the format
expected by the training pipeline.
"""
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import pytorch_lightning as pl
from safe.tokenizer import SAFETokenizer
from mol_utils.bracket_safe_converter import safe2bracketsafe
from typing import Optional, List
import re
def get_tokenizer():
    """Build the tokenizer used throughout this module.

    Loads the pretrained ``datamol-io/safe-gpt`` SAFE tokenizer, unwraps it
    with ``get_pretrained()``, and registers ``<`` and ``>`` as additional
    tokens so bracket-SAFE strings can be tokenized.

    Returns:
        The object returned by ``get_pretrained()``, with the two extra
        tokens added.
    """
    safe_wrapper = SAFETokenizer.from_pretrained('datamol-io/safe-gpt')
    tokenizer = safe_wrapper.get_pretrained()
    # Bracket-SAFE strings use '<' and '>' characters, which need their own tokens.
    bracket_tokens = ['<', '>']
    tokenizer.add_tokens(bracket_tokens)
    return tokenizer
class Collator:
    """Batch collator that tokenizes SAFE / bracket-SAFE strings.

    Each example may be either a plain string or a dict. For dicts the text is
    looked up under ``'input'``, then ``'labels'``, then ``'smiles'``; if none
    of those keys is present an empty string is used.
    """

    # Keys probed (in priority order) when an example is a dict.
    _TEXT_KEYS = ('input', 'labels', 'smiles')

    def __init__(self, config, tokenizer=None):
        """
        Args:
            config: Configuration object. ``config.interpolant.max_length`` is
                used as the truncation length and
                ``config.training.get('use_bracket_safe', False)`` toggles the
                bracket-SAFE conversion.
            tokenizer: Callable tokenizer. When ``None``, ``get_tokenizer()``
                is called to create one.
        """
        if tokenizer is None:
            tokenizer = get_tokenizer()
        self.tokenizer = tokenizer
        self.max_length = config.interpolant.max_length
        self.use_bracket_safe = config.training.get('use_bracket_safe', False)

    def _text_of(self, example):
        """Return the string to tokenize for a single example."""
        if isinstance(example, dict):
            text = ''
            for key in self._TEXT_KEYS:
                if key in example:
                    text = example[key]
                    break
        else:
            # Non-dict examples are treated as the text itself.
            text = example

        if self.use_bracket_safe:
            text = safe2bracketsafe(text)
        return text

    def __call__(self, examples):
        """Tokenize a list of examples into a padded batch.

        Args:
            examples: Iterable of strings and/or dicts (see class docstring).

        Returns:
            A plain dict holding only the tokenizer's ``'input_ids'`` and
            ``'attention_mask'`` entries; anything else the tokenizer returns
            (e.g. ``token_type_ids``) is dropped.
        """
        texts = [self._text_of(example) for example in examples]

        encoded = self.tokenizer(
            texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )

        # Copy out just the two fields the model consumes.
        return {
            field: encoded[field]
            for field in ('input_ids', 'attention_mask')
        }
class HFDatasetAdapter(Dataset):
    """Expose a HuggingFace dataset as ``{'input': ..., 'labels': ...}`` records.

    In non-streaming mode every row is read eagerly at construction time and
    rows whose text column is empty/falsy are dropped. In streaming mode the
    underlying dataset is kept as-is and converted lazily by ``__iter__``.
    """

    def __init__(self, hf_dataset, tokenizer, smiles_column='smiles', max_length=1024, convert_to_safe=False, is_streaming=False):
        """
        Args:
            hf_dataset: HuggingFace dataset object (streaming or regular).
            tokenizer: Tokenizer instance; stored on the adapter (tokenization
                itself is left to the collator).
            smiles_column: Name of the column holding the molecule string.
            max_length: Maximum sequence length; stored on the adapter.
            convert_to_safe: Flag stored on the adapter; this class does not
                act on it.
            is_streaming: Whether ``hf_dataset`` is a streaming dataset.
        """
        self.tokenizer = tokenizer
        self.smiles_column = smiles_column
        self.max_length = max_length
        self.convert_to_safe = convert_to_safe
        self.is_streaming = is_streaming

        if is_streaming:
            # Nothing is read up front; rows are pulled during iteration.
            self.data = hf_dataset
            self._length = None  # length of a stream is unknown
            print(f'Initialized streaming dataset adapter')
            return

        # Eager path: keep raw strings only, tokenization happens in the collator.
        print(f'Initializing HF dataset adapter with {len(hf_dataset)} samples...')
        texts = (row[smiles_column] for row in hf_dataset)
        self.data = [self._record(text) for text in texts if text]
        print(f'Processed {len(self.data)} valid samples')

    @staticmethod
    def _record(text):
        """Build the sample dict; the same string serves as input and label."""
        return {'input': text, 'labels': text}

    def __len__(self):
        """Return the sample count (a 10M placeholder for unsized streams)."""
        if not self.is_streaming:
            return len(self.data)
        if self._length is None:
            # Streams have no real length; a large value keeps samplers happy.
            return 10_000_000
        return self._length

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        Raises:
            NotImplementedError: In streaming mode, where random access is
                not supported.
        """
        if not self.is_streaming:
            return self.data[idx]
        raise NotImplementedError("Streaming datasets should be iterated, not indexed")

    def __iter__(self):
        """Yield sample dicts; streaming rows with empty text are skipped."""
        if not self.is_streaming:
            yield from self.data
            return

        for row in self.data:
            text = row[self.smiles_column]
            if text:
                yield self._record(text)
class HFDataModule(pl.LightningDataModule):
    """PyTorch Lightning DataModule for HuggingFace datasets.

    Supports two modes:

    * streaming (default): the dataset is loaded with ``streaming=True`` and a
      fixed-size prefix is taken as the validation set; the remainder is the
      training stream.
    * non-streaming: the dataset is fully loaded, optionally truncated, and
      split into train / validation / (optional) test sets.
    """

    # Streaming datasets have no known length, so the size of the validation
    # prefix is derived from this assumed sample count:
    #   val_size = int(_ASSUMED_STREAM_SIZE * val_split)
    _ASSUMED_STREAM_SIZE = 100_000

    # Seed used for every train_test_split call so splits are reproducible.
    _SPLIT_SEED = 42

    def __init__(
        self,
        config,
        dataset_name: str,
        tokenizer: SAFETokenizer,
        smiles_column: str = 'smiles',
        val_split: float = 0.1,
        test_split: Optional[float] = None,
        streaming: bool = True,
        max_train_samples: Optional[int] = None,
        max_val_samples: Optional[int] = None,
    ):
        """
        Args:
            config: Configuration object containing training parameters.
                Reads ``config.interpolant.max_length``,
                ``config.training.per_gpu_batch_size`` and
                ``config.training.get('cpus', 4)`` (number of loader workers).
            dataset_name: HuggingFace dataset identifier (e.g., "datamol-io/safe-gpt")
            tokenizer: Tokenizer instance, shared with the collator
            smiles_column: Name of column containing the molecule strings
            val_split: Fraction of data to use for validation
            test_split: Optional fraction of data to use for testing
                (only used in non-streaming mode)
            streaming: Whether to use streaming mode (recommended for large datasets)
            max_train_samples: Maximum number of training samples to use (for non-streaming)
            max_val_samples: Maximum number of validation samples to use (for non-streaming)
        """
        super().__init__()
        self.config = config
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.smiles_column = smiles_column
        self.max_length = config.interpolant.max_length
        self.batch_size = config.training.per_gpu_batch_size
        self.num_workers = config.training.get('cpus', 4)
        self.val_split = val_split
        self.test_split = test_split
        self.streaming = streaming
        self.max_train_samples = max_train_samples
        self.max_val_samples = max_val_samples

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        # Initialize collator
        self.collator = Collator(config, tokenizer)

    # ------------------------------------------------------------------
    # Setup
    # ------------------------------------------------------------------
    def setup(self, stage: Optional[str] = None):
        """Load and split the dataset.

        Args:
            stage: Lightning stage name; accepted for interface compatibility
                and not used to vary behavior.
        """
        print(f'Loading dataset: {self.dataset_name} (streaming={self.streaming})')
        if self.streaming:
            self._setup_streaming()
        else:
            self._setup_in_memory()

    @staticmethod
    def _train_source(raw_dataset):
        """Return the 'train' split, or the first available split if absent."""
        if 'train' in raw_dataset:
            return raw_dataset['train']
        return raw_dataset[next(iter(raw_dataset.keys()))]

    @staticmethod
    def _first_n(dataset, limit):
        """Return at most the first ``limit`` rows of a sized HF dataset."""
        return dataset.select(range(min(limit, len(dataset))))

    def _wrap_split(self, data, *, is_streaming):
        """Wrap one split in an HFDatasetAdapter using this module's settings."""
        return HFDatasetAdapter(
            data,
            self.tokenizer,
            self.smiles_column,
            self.max_length,
            is_streaming=is_streaming,
        )

    def _setup_streaming(self):
        """Create train/val adapters over a streaming dataset."""
        raw_dataset = load_dataset(self.dataset_name, streaming=True)
        stream = self._train_source(raw_dataset)

        # A stream cannot be split by fraction, so the first `val_size` rows
        # become validation data and training starts right after them.
        val_size = int(self._ASSUMED_STREAM_SIZE * self.val_split)
        val_stream = stream.take(val_size)
        train_stream = stream.skip(val_size)

        self.train_dataset = self._wrap_split(train_stream, is_streaming=True)
        self.val_dataset = self._wrap_split(val_stream, is_streaming=True)
        print('Streaming dataset initialized - samples will be loaded on-the-fly')

    def _setup_in_memory(self):
        """Create train/val/(test) adapters from a fully loaded dataset."""
        raw_dataset = load_dataset(self.dataset_name)
        train_data = self._train_source(raw_dataset)

        # Limit samples if specified
        if self.max_train_samples:
            train_data = self._first_n(train_data, self.max_train_samples)

        # Prefer a validation split shipped with the dataset; otherwise carve
        # one out of the training data.
        if 'validation' in raw_dataset or 'val' in raw_dataset:
            val_key = 'validation' if 'validation' in raw_dataset else 'val'
            val_data = raw_dataset[val_key]
        else:
            split_dataset = train_data.train_test_split(
                test_size=self.val_split, seed=self._SPLIT_SEED
            )
            train_data = split_dataset['train']
            val_data = split_dataset['test']

        # Limit validation samples if specified
        if self.max_val_samples:
            val_data = self._first_n(val_data, self.max_val_samples)

        # A test split shipped with the dataset wins; otherwise one is carved
        # out of the (remaining) training data when `test_split` is set.
        if 'test' in raw_dataset:
            self.test_dataset = self._wrap_split(raw_dataset['test'], is_streaming=False)
        elif self.test_split:
            split_dataset = train_data.train_test_split(
                test_size=self.test_split, seed=self._SPLIT_SEED
            )
            train_data = split_dataset['train']
            self.test_dataset = self._wrap_split(split_dataset['test'], is_streaming=False)

        self.train_dataset = self._wrap_split(train_data, is_streaming=False)
        self.val_dataset = self._wrap_split(val_data, is_streaming=False)

        print(f'Dataset splits - Train: {len(self.train_dataset)}, Val: {len(self.val_dataset)}')
        if self.test_dataset:
            print(f'Test: {len(self.test_dataset)}')

    # ------------------------------------------------------------------
    # DataLoaders
    # ------------------------------------------------------------------
    def _make_stream_loader(self, adapter):
        """Build a DataLoader over the raw HF stream held by ``adapter``.

        NOTE: the loader is given ``adapter.data`` (the underlying HF
        streaming dataset), not the adapter itself, so the collator receives
        the dataset's original rows and the adapter's ``smiles_column``
        selection / empty-row filtering is not applied on this path.
        """
        return DataLoader(
            adapter.data,
            batch_size=self.batch_size,
            collate_fn=self.collator,
            num_workers=0,  # Required for streaming with skip/take operations
            pin_memory=True,
            shuffle=False,  # Cannot shuffle streaming datasets
        )

    def _make_map_loader(self, dataset, *, shuffle):
        """Build a DataLoader over a fully materialized adapter."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=self.collator,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            # Persistent workers are only valid when workers exist.
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        """Return the training DataLoader (shuffled only in non-streaming mode)."""
        if self.streaming:
            return self._make_stream_loader(self.train_dataset)
        return self._make_map_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation DataLoader (never shuffled)."""
        if self.streaming:
            return self._make_stream_loader(self.val_dataset)
        return self._make_map_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the test DataLoader, or ``None`` when there is no test data."""
        if self.test_dataset:
            return self._make_map_loader(self.test_dataset, shuffle=False)
        return None
def setup_hf_data_and_update_config(config, dataset_name="datamol-io/safe-gpt", smiles_column="smiles", streaming=True):
    """
    Create the HuggingFace data module and record tokenizer facts in ``config``.

    The tokenizer from ``get_tokenizer()`` is used to fill in
    ``config.interpolant.tokens`` (vocabulary size), ``pad_token`` and
    ``mask_token`` (the tokenizer's pad / mask token ids). The config object
    is modified in place.

    Args:
        config: Hydra config object
        dataset_name: HuggingFace dataset identifier
        smiles_column: Name of column containing the molecule strings
        streaming: Whether to use streaming mode (recommended for large datasets like safe-gpt)

    Returns:
        HFDataModule instance (built with a validation split of 0.1)
    """
    tk = get_tokenizer()

    # Publish tokenizer-derived values on the config.
    config.interpolant.tokens = len(tk)
    config.interpolant.pad_token = tk.pad_token_id
    config.interpolant.mask_token = tk.mask_token_id

    return HFDataModule(
        config=config,
        dataset_name=dataset_name,
        tokenizer=tk,
        smiles_column=smiles_column,
        val_split=0.1,
        streaming=streaming,
    )