# atlas-1-demo / src/data/datasets.py
# (HuggingFace upload metadata: "Reverb's picture", "Upload folder using
#  huggingface_hub", commit 8eabce6 verified)
"""
Dataset Loaders for Molecular Property Prediction
Supports loading datasets from Hugging Face and PyTorch Geometric.
"""
import torch
import numpy as np
from torch.utils.data import Dataset
from torch_geometric.data import Data, InMemoryDataset
from datasets import load_dataset
from typing import List, Tuple, Optional, Dict
import os
class QM9Dataset(InMemoryDataset):
    """
    QM9 dataset loader from Hugging Face.

    Loads the QM9 dataset (133K molecules) with 3D coordinates and quantum
    properties, converts each molecule into a PyTorch Geometric ``Data``
    object (one-hot atom features, positions, distance-based edges), and
    caches the processed tensors under ``root``.

    Args:
        root (str): Root directory for dataset storage
        hf_dataset_path (str): Hugging Face dataset path (default: "yairschiff/qm9")
        target_property (str): Target property to predict (default: "homo")
        split (str): Dataset split ("train", "val", "test")
        transform: Optional transform applied on every access
        pre_transform: Optional pre-transform applied once during processing
    """

    # Atomic symbol -> atomic number for the first ten elements (QM9 itself
    # only contains H, C, N, O, F). Hoisted to class level so the dict is not
    # rebuilt for every one of the ~133K molecules in process().
    _SYMBOL_TO_NUMBER: Dict[str, int] = {
        'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5, 'C': 6,
        'N': 7, 'O': 8, 'F': 9, 'Ne': 10
    }

    def __init__(
        self,
        root: str = "data/qm9",
        hf_dataset_path: str = "yairschiff/qm9",
        target_property: str = "homo",
        split: str = "train",
        transform=None,
        pre_transform=None,
    ):
        self.hf_dataset_path = hf_dataset_path
        self.target_property = target_property
        self.split = split
        super().__init__(root, transform, pre_transform)
        # weights_only=False is required because the cache stores pickled PyG
        # Data objects; safe only because this process wrote the file itself.
        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

    @property
    def raw_file_names(self) -> List[str]:
        return []  # We load directly from HuggingFace

    @property
    def processed_file_names(self) -> List[str]:
        return [f'qm9_{self.split}_processed.pt']

    def download(self):
        """Download is handled by HuggingFace datasets library."""
        pass

    def process(self):
        """Process raw HF samples into PyTorch Geometric format and cache them."""
        print(f"Loading QM9 dataset from Hugging Face: {self.hf_dataset_path}")
        # Load dataset directly from Hugging Face
        raw_dataset = load_dataset(self.hf_dataset_path, split=self.split)
        print(f"Loaded {len(raw_dataset)} molecules for {self.split} split")
        data_list = []
        print("Processing molecules...")
        for idx, sample in enumerate(raw_dataset):
            if idx % 10000 == 0:
                print(f"  Processed {idx}/{len(raw_dataset)} molecules")
            # Extract features - use correct column names from HF dataset
            atomic_symbols = sample['atomic_symbols']  # e.g. ['C', 'H', 'H', ...]
            positions = torch.tensor(sample['pos'], dtype=torch.float32)  # [num_atoms, 3]
            # Convert atomic symbols to atomic numbers; unknown symbols map to 0.
            atomic_numbers = torch.tensor(
                [self._SYMBOL_TO_NUMBER.get(sym, 0) for sym in atomic_symbols],
                dtype=torch.long,
            )
            num_atoms = atomic_numbers.numel()
            # One-hot encode via tensor indexing instead of a Python loop:
            # columns 0-9 are H..Ne (z-1), column 10 collects unknown atoms.
            node_features = torch.zeros(num_atoms, 11)
            known = (atomic_numbers >= 1) & (atomic_numbers <= 10)
            cols = torch.where(known, atomic_numbers - 1,
                               torch.full_like(atomic_numbers, 10))
            node_features[torch.arange(num_atoms), cols] = 1.0
            # Get target property (scalar regression target)
            target = torch.tensor([sample[self.target_property]], dtype=torch.float32)
            # Distance-based edges (5.0 Angstrom cutoff)
            edge_index = self._get_edge_index(num_atoms, cutoff=5.0, positions=positions)
            # Create PyG Data object
            data = Data(
                x=node_features,
                pos=positions,
                edge_index=edge_index,
                y=target,
                num_atoms=num_atoms,
            )
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)
        # Save processed data
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        print(f"Processed {len(data_list)} molecules")

    def _get_edge_index(self, num_atoms: int, cutoff: float = 5.0, positions: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Create edge index based on distance cutoff.

        Args:
            num_atoms: Number of atoms in molecule
            cutoff: Distance cutoff in Angstroms
            positions: Atom positions [num_atoms, 3] (optional)

        Returns:
            edge_index: Edge indices [2, num_edges], dtype long
        """
        if positions is not None:
            # Distance-based edges: connect every pair closer than `cutoff`,
            # excluding self-loops (distance exactly 0).
            dist_matrix = torch.cdist(positions, positions)
            within_cutoff = (dist_matrix < cutoff) & (dist_matrix > 0)
            edge_index = within_cutoff.nonzero().t()
        else:
            # Fallback: fully connected graph without self-loops. Built with
            # tensor ops instead of the previous O(n^2) Python double loop;
            # repeat_interleave/repeat reproduce the same (i, j) ordering.
            idx = torch.arange(num_atoms, dtype=torch.long)
            row = idx.repeat_interleave(num_atoms)
            col = idx.repeat(num_atoms)
            mask = row != col
            edge_index = torch.stack([row[mask], col[mask]])
        return edge_index
class MolecularDatasetWrapper:
    """
    Wrapper for molecular datasets with preprocessing and normalization.

    Args:
        dataset: Base dataset (e.g., QM9Dataset, or a Subset of one)
        normalize_targets: Whether to normalize target values
        rotation_augmentation: Whether to apply random 3D rotations to positions
        target_mean: Precomputed normalization mean; pass the training set's
            statistic for val/test so all splits share one scale. Computed
            from `dataset` when omitted.
        target_std: Precomputed normalization std (see `target_mean`).
    """

    def __init__(
        self,
        dataset: InMemoryDataset,
        normalize_targets: bool = True,
        rotation_augmentation: bool = False,
        target_mean: Optional[float] = None,
        target_std: Optional[float] = None,
    ):
        self.dataset = dataset
        self.normalize_targets = normalize_targets
        self.rotation_augmentation = rotation_augmentation
        # Use provided stats or compute from dataset
        if normalize_targets:
            if target_mean is not None and target_std is not None:
                # Use provided normalization stats (for val/test)
                self.target_mean = target_mean
                self.target_std = target_std
            else:
                # Compute from this dataset (for train)
                self.target_mean, self.target_std = self._compute_normalization_stats()
        else:
            # Identity normalization so __getitem__/denormalize are no-ops.
            self.target_mean = 0.0
            self.target_std = 1.0

    def _compute_normalization_stats(self) -> Tuple[float, float]:
        """Compute mean and std of target values.

        Returns (0.0, 1.0) for an empty dataset, and clamps a zero std to 1.0
        so __getitem__ never divides by zero when all targets are identical.
        """
        targets = np.array(
            [self.dataset[i].y.item() for i in range(len(self.dataset))]
        )
        if targets.size == 0:
            return 0.0, 1.0
        std = float(targets.std())
        if std == 0.0:
            std = 1.0
        return float(targets.mean()), std

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Data:
        # Clone so normalization/augmentation never mutate the cached sample.
        data = self.dataset[idx].clone()
        # Normalize targets
        if self.normalize_targets:
            data.y = (data.y - self.target_mean) / self.target_std
        # Apply rotation augmentation (only if enabled)
        if self.rotation_augmentation:
            data.pos = self._random_rotation(data.pos)
        return data

    def _random_rotation(self, pos: torch.Tensor) -> torch.Tensor:
        """Apply a random 3D rotation (uniform Euler angles) to positions."""
        angles = torch.rand(3) * 2 * np.pi
        # Rotation around x-axis
        Rx = torch.tensor([
            [1, 0, 0],
            [0, torch.cos(angles[0]), -torch.sin(angles[0])],
            [0, torch.sin(angles[0]), torch.cos(angles[0])]
        ], dtype=pos.dtype)
        # Rotation around y-axis
        Ry = torch.tensor([
            [torch.cos(angles[1]), 0, torch.sin(angles[1])],
            [0, 1, 0],
            [-torch.sin(angles[1]), 0, torch.cos(angles[1])]
        ], dtype=pos.dtype)
        # Rotation around z-axis
        Rz = torch.tensor([
            [torch.cos(angles[2]), -torch.sin(angles[2]), 0],
            [torch.sin(angles[2]), torch.cos(angles[2]), 0],
            [0, 0, 1]
        ], dtype=pos.dtype)
        # Combined rotation (applied x, then y, then z)
        R = Rz @ Ry @ Rx
        return pos @ R.T

    def denormalize(self, normalized_values: torch.Tensor) -> torch.Tensor:
        """Denormalize predictions back to original scale."""
        return normalized_values * self.target_std + self.target_mean
def create_qm9_dataloaders(
    root: str = "data/qm9",
    hf_dataset_path: str = "yairschiff/qm9",
    target_property: str = "homo",
    batch_size: int = 32,
    num_workers: int = 4,
    normalize_targets: bool = True,
    rotation_augmentation: bool = True,
    train_split: float = 0.8,
    val_split: float = 0.1,
) -> Dict[str, torch.utils.data.DataLoader]:
    """
    Create train/val/test dataloaders for QM9.

    Args:
        root: Root directory for dataset
        hf_dataset_path: Hugging Face dataset path
        target_property: Target property to predict
        batch_size: Batch size
        num_workers: Number of data loading workers
        normalize_targets: Whether to normalize targets (stats come from the
            training split only, to avoid leakage into val/test)
        rotation_augmentation: Whether to use rotation augmentation (train only)
        train_split: Fraction for training (default: 0.8)
        val_split: Fraction for validation (default: 0.1)

    Returns:
        Dictionary with 'train', 'val', 'test' dataloaders, plus
        'train_dataset' (the wrapped training set, for denormalization)

    Raises:
        ValueError: If the split fractions are out of range or sum past 1.
    """
    from torch_geometric.loader import DataLoader
    from torch.utils.data import random_split

    # Fail fast on bad fractions instead of letting random_split error out
    # (or silently produce a negative/zero test set) deep inside torch.
    if train_split <= 0.0 or val_split < 0.0 or train_split + val_split > 1.0:
        raise ValueError(
            f"Invalid split fractions: train_split={train_split}, "
            f"val_split={val_split}; require train_split > 0, val_split >= 0 "
            f"and train_split + val_split <= 1"
        )

    # Load the full dataset (only 'train' split exists in HF)
    print("Loading full QM9 dataset...")
    full_dataset = QM9Dataset(
        root=root,
        hf_dataset_path=hf_dataset_path,
        target_property=target_property,
        split="train"  # HF only has 'train' split
    )
    # Calculate split sizes; test gets the remainder after int truncation
    total_size = len(full_dataset)
    train_size = int(train_split * total_size)
    val_size = int(val_split * total_size)
    test_size = total_size - train_size - val_size
    print(f"Splitting dataset: train={train_size}, val={val_size}, test={test_size}")
    # Random split with a fixed seed so splits are reproducible across runs
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )
    # Wrap datasets - compute normalization from TRAINING set only!
    train_wrapped = MolecularDatasetWrapper(
        train_dataset,
        normalize_targets=normalize_targets,
        rotation_augmentation=rotation_augmentation
    )
    # Val and test use TRAINING set's normalization stats
    val_wrapped = MolecularDatasetWrapper(
        val_dataset,
        normalize_targets=normalize_targets,
        rotation_augmentation=False,  # No augmentation for validation
        target_mean=train_wrapped.target_mean if normalize_targets else None,
        target_std=train_wrapped.target_std if normalize_targets else None,
    )
    test_wrapped = MolecularDatasetWrapper(
        test_dataset,
        normalize_targets=normalize_targets,
        rotation_augmentation=False,  # No augmentation for test
        target_mean=train_wrapped.target_mean if normalize_targets else None,
        target_std=train_wrapped.target_std if normalize_targets else None,
    )
    # Create dataloaders (only training data is shuffled)
    train_loader = DataLoader(
        train_wrapped,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_wrapped,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    test_loader = DataLoader(
        test_wrapped,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    return {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader,
        'train_dataset': train_wrapped,  # For denormalization
    }