| |
| """ |
| Adapter to use HuggingFace datasets with the any-length discrete diffusion model. |
| This module converts HuggingFace datasets (like datamol-io/safe-drugs) into the format |
| expected by the training pipeline. |
| """ |
|
|
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| from datasets import load_dataset |
| import pytorch_lightning as pl |
| from safe.tokenizer import SAFETokenizer |
| from mol_utils.bracket_safe_converter import safe2bracketsafe |
| from typing import Optional, List |
| import re |
|
|
|
|
def get_tokenizer():
    """Load the pretrained ``datamol-io/safe-gpt`` SAFE tokenizer and register '<' and '>' as extra tokens.

    Returns:
        The tokenizer object produced by ``SAFETokenizer.get_pretrained()`` with
        the two extra tokens added. (Presumably '<' and '>' are needed for the
        bracket-SAFE format — confirm against ``safe2bracketsafe``.)
    """
    safe_tokenizer = SAFETokenizer.from_pretrained('datamol-io/safe-gpt')
    pretrained = safe_tokenizer.get_pretrained()
    pretrained.add_tokens(['<', '>'])
    return pretrained
|
|
|
|
class Collator:
    """Batch collator that turns raw SAFE / bracket-SAFE examples into token tensors.

    An example is either a plain string or a dict. For a dict, the text is the
    value of the first key present among ``'input'``, ``'labels'`` and
    ``'smiles'``, falling back to an empty string when none is present.
    """

    # Dict keys consulted, in priority order, to locate an example's text.
    _TEXT_KEYS = ('input', 'labels', 'smiles')

    def __init__(self, config, tokenizer=None):
        """
        Args:
            config: Config exposing ``interpolant.max_length`` and a
                ``training.get(...)`` lookup for the optional ``use_bracket_safe`` flag.
            tokenizer: Tokenizer to apply; when ``None`` a fresh one is built
                with ``get_tokenizer()``.
        """
        if tokenizer is None:
            tokenizer = get_tokenizer()
        self.tokenizer = tokenizer
        self.max_length = config.interpolant.max_length
        self.use_bracket_safe = config.training.get('use_bracket_safe', False)

    @classmethod
    def _example_text(cls, example):
        """Return the raw text carried by a single example (string or dict)."""
        if not isinstance(example, dict):
            return example
        missing = object()
        for key in cls._TEXT_KEYS:
            value = example.get(key, missing)
            if value is not missing:
                return value
        return ''

    def __call__(self, examples):
        """Collate examples into a dict with ``input_ids`` and ``attention_mask``.

        Texts are optionally rewritten with ``safe2bracketsafe``. They are then
        tokenized with padding enabled and truncation at ``self.max_length``,
        and returned as PyTorch tensors (``return_tensors='pt'``).
        """
        texts = [self._example_text(example) for example in examples]
        if self.use_bracket_safe:
            texts = [safe2bracketsafe(text) for text in texts]

        encoded = self.tokenizer(
            texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        # Only these two fields are forwarded; any other tokenizer outputs are dropped.
        return {key: encoded[key] for key in ('input_ids', 'attention_mask')}
|
|
|
|
class HFDatasetAdapter(Dataset):
    """Expose a Hugging Face dataset as ``{'input': text, 'labels': text}`` records.

    In non-streaming mode, every row is read once at construction. Rows whose
    text column is falsy (empty string, ``None``) are dropped, and the result
    is an indexable list. In streaming mode, the underlying dataset is kept
    as-is: it can only be iterated (``__iter__`` applies the same mapping and
    filtering lazily), and indexing raises ``NotImplementedError``.
    """

    # Length reported for streaming datasets, whose true size is not known up front.
    _STREAMING_PLACEHOLDER_LENGTH = 10_000_000

    def __init__(self, hf_dataset, tokenizer, smiles_column='smiles', max_length=1024, convert_to_safe=False, is_streaming=False):
        """
        Args:
            hf_dataset: Hugging Face dataset object (streaming or regular); rows
                must support ``row[smiles_column]``.
            tokenizer: Tokenizer instance; stored on the adapter but not used by
                this class.
            smiles_column: Name of the column holding the molecule strings.
            max_length: Maximum sequence length; stored but not used by this class.
            convert_to_safe: Accepted for API compatibility only. SMILES→SAFE
                conversion is NOT implemented. Passing ``True`` emits a
                ``UserWarning`` and the texts are used unchanged.
            is_streaming: Whether ``hf_dataset`` is a streaming (iterable) dataset.
        """
        self.tokenizer = tokenizer
        self.smiles_column = smiles_column
        self.max_length = max_length
        self.convert_to_safe = convert_to_safe
        self.is_streaming = is_streaming

        if convert_to_safe:
            # Previously this flag was silently ignored; make the no-op visible.
            import warnings
            warnings.warn(
                'HFDatasetAdapter: convert_to_safe=True is not implemented; '
                'texts are used exactly as stored in the dataset.',
                UserWarning,
                stacklevel=2,
            )

        if is_streaming:
            self.data = hf_dataset
            # Nothing in this class assigns a real length; __len__ falls back to the placeholder.
            self._length = None
            print('Initialized streaming dataset adapter')
        else:
            print(f'Initializing HF dataset adapter with {len(hf_dataset)} samples...')
            self.data = [
                record
                for record in map(self._to_record, hf_dataset)
                if record is not None
            ]
            print(f'Processed {len(self.data)} valid samples')

    def _to_record(self, item):
        """Map one raw row to an ``{'input', 'labels'}`` record, or ``None`` if its text is falsy."""
        text = item[self.smiles_column]
        if not text:
            return None
        return {'input': text, 'labels': text}

    def __len__(self):
        """Number of records, or a large placeholder for streaming datasets."""
        if self.is_streaming:
            if self._length is None:
                return self._STREAMING_PLACEHOLDER_LENGTH
            return self._length
        return len(self.data)

    def __getitem__(self, idx):
        """Return record ``idx``; not supported for streaming datasets."""
        if self.is_streaming:
            raise NotImplementedError("Streaming datasets should be iterated, not indexed")
        return self.data[idx]

    def __iter__(self):
        """Yield records in order; streaming rows are mapped and filtered lazily."""
        if not self.is_streaming:
            yield from self.data
            return
        for item in self.data:
            record = self._to_record(item)
            if record is not None:
                yield record
|
|
|
|
class HFDataModule(pl.LightningDataModule):
    """PyTorch Lightning DataModule serving a Hugging Face dataset.

    Two modes are supported:

    * streaming: the first ``int(streaming_val_pool_size * val_split)`` records
      of the primary split are held out for validation. Training streams every
      record after them, so the two never overlap.
    * non-streaming: the dataset is loaded, split into train/val(/test) with a
      fixed seed of 42, and materialized through ``HFDatasetAdapter``.
    """

    def __init__(
        self,
        config,
        dataset_name: str,
        tokenizer: SAFETokenizer,
        smiles_column: str = 'smiles',
        val_split: float = 0.1,
        test_split: Optional[float] = None,
        streaming: bool = True,
        max_train_samples: Optional[int] = None,
        max_val_samples: Optional[int] = None,
        streaming_val_pool_size: int = 100_000,
    ):
        """
        Args:
            config: Configuration object. Reads ``interpolant.max_length``,
                ``training.per_gpu_batch_size`` and ``training.get('cpus', 4)``
                (number of DataLoader workers).
            dataset_name: Hugging Face dataset identifier (e.g. "datamol-io/safe-gpt").
            tokenizer: Tokenizer instance handed to the ``Collator``.
            smiles_column: Name of the column containing the molecule strings.
            val_split: Fraction of data used for validation.
            test_split: Optional fraction of training data carved out for testing
                (non-streaming only, and only if the dataset has no 'test' split).
            streaming: Whether to use streaming mode (recommended for large datasets).
            max_train_samples: Cap on training samples (non-streaming only; falsy = no cap).
            max_val_samples: Cap on validation samples (non-streaming only; falsy = no cap).
            streaming_val_pool_size: In streaming mode, the validation set is
                ``int(streaming_val_pool_size * val_split)`` records taken from the
                head of the stream. Defaults to 100_000, matching the previous
                hard-coded value.
        """
        super().__init__()
        self.config = config
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.smiles_column = smiles_column
        self.max_length = config.interpolant.max_length
        self.batch_size = config.training.per_gpu_batch_size
        self.num_workers = config.training.get('cpus', 4)
        self.val_split = val_split
        self.test_split = test_split
        self.streaming = streaming
        self.max_train_samples = max_train_samples
        self.max_val_samples = max_val_samples
        self.streaming_val_pool_size = streaming_val_pool_size

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        self.collator = Collator(config, tokenizer)

    @staticmethod
    def _primary_split(raw_dataset):
        """Return the 'train' split if present, otherwise the first split listed."""
        if 'train' in raw_dataset:
            return raw_dataset['train']
        return raw_dataset[list(raw_dataset.keys())[0]]

    def _adapter(self, data, is_streaming):
        """Wrap ``data`` in an ``HFDatasetAdapter`` using this module's column and length settings."""
        return HFDatasetAdapter(
            data,
            self.tokenizer,
            self.smiles_column,
            self.max_length,
            is_streaming=is_streaming,
        )

    def setup(self, stage: Optional[str] = None):
        """Load the dataset and build the splits.

        Args:
            stage: Lightning stage name; ignored. Every call rebuilds all splits.
        """
        print(f'Loading dataset: {self.dataset_name} (streaming={self.streaming})')
        if self.streaming:
            self._setup_streaming()
        else:
            self._setup_in_memory()

    def _setup_streaming(self):
        """Build streaming train/val adapters by splitting the head of the primary stream."""
        raw_dataset = load_dataset(self.dataset_name, streaming=True)
        train_stream = self._primary_split(raw_dataset)

        # The first val_size records become validation; training skips exactly
        # those records, so the two splits are disjoint.
        val_size = int(self.streaming_val_pool_size * self.val_split)
        val_stream = train_stream.take(val_size)
        train_stream = train_stream.skip(val_size)

        self.train_dataset = self._adapter(train_stream, is_streaming=True)
        self.val_dataset = self._adapter(val_stream, is_streaming=True)
        print('Streaming dataset initialized - samples will be loaded on-the-fly')

    def _setup_in_memory(self):
        """Load the full dataset and build train/val(/test) adapters."""
        raw_dataset = load_dataset(self.dataset_name)
        train_data = self._primary_split(raw_dataset)

        if self.max_train_samples:
            train_data = train_data.select(range(min(self.max_train_samples, len(train_data))))

        # Prefer an explicit validation split; otherwise carve one out of train.
        val_key = next((key for key in ('validation', 'val') if key in raw_dataset), None)
        if val_key is not None:
            val_data = raw_dataset[val_key]
        else:
            split = train_data.train_test_split(test_size=self.val_split, seed=42)
            train_data, val_data = split['train'], split['test']

        if self.max_val_samples:
            val_data = val_data.select(range(min(self.max_val_samples, len(val_data))))

        # An explicit 'test' split wins; otherwise optionally carve one out of train.
        if 'test' in raw_dataset:
            self.test_dataset = self._adapter(raw_dataset['test'], is_streaming=False)
        elif self.test_split:
            split = train_data.train_test_split(test_size=self.test_split, seed=42)
            train_data = split['train']
            self.test_dataset = self._adapter(split['test'], is_streaming=False)

        self.train_dataset = self._adapter(train_data, is_streaming=False)
        self.val_dataset = self._adapter(val_data, is_streaming=False)

        print(f'Dataset splits - Train: {len(self.train_dataset)}, Val: {len(self.val_dataset)}')
        if self.test_dataset:
            print(f'Test: {len(self.test_dataset)}')

    def _streaming_loader(self, adapter):
        """DataLoader over the raw HF stream held by a streaming ``adapter``.

        The adapter is map-style and refuses indexing in streaming mode, so its
        underlying stream is passed directly. NOTE(review): raw rows therefore
        bypass the adapter's ``smiles_column`` mapping and empty-row filtering.
        The Collator picks the text from the 'input'/'labels'/'smiles' keys
        instead.
        """
        return DataLoader(
            adapter.data,
            batch_size=self.batch_size,
            collate_fn=self.collator,
            num_workers=0,
            pin_memory=True,
            shuffle=False,
        )

    def _in_memory_loader(self, dataset, shuffle):
        """DataLoader over a materialized adapter using the configured worker count."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=self.collator,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        """Training DataLoader (shuffled in non-streaming mode)."""
        if self.streaming:
            return self._streaming_loader(self.train_dataset)
        return self._in_memory_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Validation DataLoader (never shuffled)."""
        if self.streaming:
            return self._streaming_loader(self.val_dataset)
        return self._in_memory_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Test DataLoader, or ``None`` when no (non-empty) test split exists."""
        if not self.test_dataset:
            return None
        return self._in_memory_loader(self.test_dataset, shuffle=False)
|
|
|
|
def setup_hf_data_and_update_config(config, dataset_name="datamol-io/safe-gpt", smiles_column="smiles", streaming=True, val_split=0.1, **datamodule_kwargs):
    """
    Build the tokenizer, record its vocabulary info on ``config``, and create an ``HFDataModule``.

    Side effect: mutates ``config.interpolant``. It sets ``tokens`` to
    ``len(tokenizer)``, ``pad_token`` to ``tokenizer.pad_token_id`` and
    ``mask_token`` to ``tokenizer.mask_token_id``.

    Args:
        config: Hydra config object (must expose ``interpolant``).
        dataset_name: Hugging Face dataset identifier.
        smiles_column: Name of the column containing the molecule strings.
        streaming: Whether to use streaming mode (recommended for large datasets like safe-gpt).
        val_split: Fraction of data used for validation (default 0.1, the
            value that was previously hard-coded).
        **datamodule_kwargs: Extra keyword arguments forwarded unchanged to
            ``HFDataModule`` (e.g. ``test_split``, ``max_train_samples``,
            ``max_val_samples``).

    Returns:
        HFDataModule instance (``setup`` has not been called yet).
    """
    tokenizer = get_tokenizer()

    # The tokenizer info must be on the config before HFDataModule is built,
    # since the module receives the same config object.
    config.interpolant.tokens = len(tokenizer)
    config.interpolant.pad_token = tokenizer.pad_token_id
    config.interpolant.mask_token = tokenizer.mask_token_id

    return HFDataModule(
        config=config,
        dataset_name=dataset_name,
        tokenizer=tokenizer,
        smiles_column=smiles_column,
        val_split=val_split,
        streaming=streaming,
        **datamodule_kwargs,
    )
|
|