"""
Dataset Loaders for Molecular Property Prediction

Supports loading datasets from Hugging Face and PyTorch Geometric.
"""

import torch
import numpy as np
from torch.utils.data import Dataset
from torch_geometric.data import Data, InMemoryDataset
from datasets import load_dataset
from typing import List, Tuple, Optional, Dict
import os


# Element symbol -> atomic number for H..Ne. Symbols not listed here are mapped
# to 0 and end up in the "unknown" one-hot column.
_SYMBOL_TO_ATOMIC_NUMBER: Dict[str, int] = {
    'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5,
    'C': 6, 'N': 7, 'O': 8, 'F': 9, 'Ne': 10,
}

# One-hot width: columns 0..9 encode Z = 1..10, column 10 encodes "unknown".
_NUM_ATOM_TYPES = 11


class QM9Dataset(InMemoryDataset):
    """
    QM9 dataset loader from Hugging Face.

    Loads the QM9 dataset (133K molecules) with 3D coordinates and quantum
    properties, and converts every molecule into a PyG ``Data`` object with
    one-hot atom types (``x``), coordinates (``pos``), distance-cutoff edges
    (``edge_index``), a scalar target (``y``) and ``num_atoms``.

    The processed result is cached in ``processed_paths[0]``; the cache file
    name includes both the split and the target property, so requesting a
    different target re-processes instead of reusing stale targets.

    Args:
        root (str): Root directory for dataset storage
        hf_dataset_path (str): Hugging Face dataset path (default: "yairschiff/qm9")
        target_property (str): Target property (HF column) to predict (default: "homo")
        split (str): Hugging Face split to load; create_qm9_dataloaders uses
            "train" and splits it locally into train/val/test
        transform: Optional transform to apply
        pre_transform: Optional pre-transform to apply
    """

    def __init__(
        self,
        root: str = "data/qm9",
        hf_dataset_path: str = "yairschiff/qm9",
        target_property: str = "homo",
        split: str = "train",
        transform=None,
        pre_transform=None,
    ):
        # Set before super().__init__(): PyG consults processed_file_names there
        # and runs process() when the cached file is missing, both of which read
        # these attributes.
        self.hf_dataset_path = hf_dataset_path
        self.target_property = target_property
        self.split = split
        super().__init__(root, transform, pre_transform)
        # weights_only=False unpickles arbitrary objects. That is acceptable only
        # because this file is written locally by process(); never load untrusted
        # files this way.
        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

    @property
    def raw_file_names(self) -> List[str]:
        return []  # We load directly from HuggingFace

    @property
    def processed_file_names(self) -> List[str]:
        # The target is baked into the processed tensors, so it must be part of
        # the cache key; otherwise changing target_property reuses old targets.
        return [f'qm9_{self.split}_{self.target_property}_processed.pt']

    def download(self):
        """Download is handled by HuggingFace datasets library."""
        pass

    def process(self):
        """Load the HF split, convert every molecule to ``Data`` and save the collated result."""
        print(f"Loading QM9 dataset from Hugging Face: {self.hf_dataset_path}")

        # Load dataset directly from Hugging Face
        raw_dataset = load_dataset(self.hf_dataset_path, split=self.split)
        print(f"Loaded {len(raw_dataset)} molecules for {self.split} split")

        # Fail fast with an actionable message instead of a bare KeyError on the
        # first sample.
        if self.target_property not in raw_dataset.column_names:
            raise KeyError(
                f"Target property {self.target_property!r} not found in "
                f"{self.hf_dataset_path}; available columns: {raw_dataset.column_names}"
            )

        data_list = []
        print("Processing molecules...")
        for idx, sample in enumerate(raw_dataset):
            if idx % 10000 == 0:
                print(f"  Processed {idx}/{len(raw_dataset)} molecules")

            data = self._sample_to_data(sample)
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)

        # Save processed data
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        print(f"Processed {len(data_list)} molecules")

    def _sample_to_data(self, sample: Dict) -> Data:
        """Convert one Hugging Face QM9 record into a PyG ``Data`` object."""
        # 'atomic_symbols' is a list like ['C', 'H', 'H', ...]; 'pos' holds 3D coordinates.
        atomic_symbols = sample['atomic_symbols']
        positions = torch.tensor(sample['pos'], dtype=torch.float32)

        atomic_numbers = torch.tensor(
            [_SYMBOL_TO_ATOMIC_NUMBER.get(sym, 0) for sym in atomic_symbols],
            dtype=torch.long,
        )
        num_atoms = atomic_numbers.numel()

        # One-hot encode: Z in 1..10 -> column Z-1, anything else -> last ("unknown") column.
        known = (atomic_numbers >= 1) & (atomic_numbers <= 10)
        type_index = torch.where(
            known,
            atomic_numbers - 1,
            torch.full_like(atomic_numbers, _NUM_ATOM_TYPES - 1),
        )
        node_features = torch.zeros(num_atoms, _NUM_ATOM_TYPES)
        node_features[torch.arange(num_atoms), type_index] = 1.0

        # Get target property
        target = torch.tensor([sample[self.target_property]], dtype=torch.float32)

        # Distance-based edges (5.0 Angstrom cutoff)
        edge_index = self._get_edge_index(num_atoms, cutoff=5.0, positions=positions)

        return Data(
            x=node_features,
            pos=positions,
            edge_index=edge_index,
            y=target,
            num_atoms=num_atoms,
        )

    def _get_edge_index(self, num_atoms: int, cutoff: float = 5.0,
                        positions: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Create edge index based on distance cutoff.

        Args:
            num_atoms: Number of atoms in molecule (only used by the
                fully-connected fallback; otherwise taken from ``positions``)
            cutoff: Distance cutoff in Angstroms (strict: ``dist < cutoff``)
            positions: Atom positions [num_atoms, 3] (optional)

        Returns:
            edge_index: Edge indices [2, num_edges] (int64), directed in both
            directions, never containing self-loops
        """
        if positions is not None:
            num_nodes = positions.size(0)
            # Exact pairwise differences: cdist's default mode switches to the
            # ||a||^2 + ||b||^2 - 2ab expansion for more than 25 points, which can
            # produce small non-zero self-distances (and so spurious self-loops)
            # and less accurate distances near the cutoff.
            dist_matrix = torch.cdist(
                positions, positions, compute_mode='donot_use_mm_for_euclid_dist'
            )
            not_self = ~torch.eye(num_nodes, dtype=torch.bool, device=positions.device)
            # dist > 0 additionally drops zero-length edges between coincident atoms.
            within_cutoff = (dist_matrix < cutoff) & (dist_matrix > 0) & not_self
            return within_cutoff.nonzero().t()

        # Fallback: fully connected without self-loops (for backward compatibility).
        # nonzero() walks row-major, giving the same (i, j) order as a nested loop.
        not_self = ~torch.eye(num_atoms, dtype=torch.bool)
        return not_self.nonzero().t()


class MolecularDatasetWrapper:
    """
    Wrapper for molecular datasets with preprocessing and normalization.

    Args:
        dataset: Base dataset yielding ``Data`` objects (e.g. QM9Dataset or a
            ``Subset`` of it)
        normalize_targets: Whether to normalize target values
        rotation_augmentation: Whether to apply random rotations to ``pos``
        target_mean: Precomputed target mean (e.g. from the training set)
        target_std: Precomputed target std (e.g. from the training set)

    If normalization is enabled and both ``target_mean`` and ``target_std``
    are given, they are used as-is; otherwise the statistics are computed
    from ``dataset``.
    """

    def __init__(
        self,
        dataset: Dataset,
        normalize_targets: bool = True,
        rotation_augmentation: bool = False,
        target_mean: Optional[float] = None,
        target_std: Optional[float] = None,
    ):
        self.dataset = dataset
        self.normalize_targets = normalize_targets
        self.rotation_augmentation = rotation_augmentation

        if not normalize_targets:
            # Identity normalization keeps denormalize() a no-op.
            self.target_mean = 0.0
            self.target_std = 1.0
        elif target_mean is not None and target_std is not None:
            # Use provided normalization stats (for val/test)
            self.target_mean = target_mean
            self.target_std = target_std
        else:
            # Compute from this dataset (for train)
            self.target_mean, self.target_std = self._compute_normalization_stats()

    def _compute_normalization_stats(self) -> Tuple[float, float]:
        """
        Compute mean and (population) std of the scalar targets.

        Returns:
            (mean, std) as Python floats. A zero std (constant targets) is
            replaced by 1.0 so normalization never divides by zero.

        Raises:
            ValueError: if the dataset is empty.
        """
        targets = np.array(
            [self.dataset[i].y.item() for i in range(len(self.dataset))],
            dtype=np.float64,
        )
        if targets.size == 0:
            raise ValueError("Cannot compute normalization statistics from an empty dataset")

        mean = float(targets.mean())
        std = float(targets.std())
        if std == 0.0:
            std = 1.0
        return mean, std

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Data:
        # Clone so normalization/augmentation never mutate the underlying dataset.
        data = self.dataset[idx].clone()

        # Normalize targets
        if self.normalize_targets:
            data.y = (data.y - self.target_mean) / self.target_std

        # Apply rotation augmentation (only if enabled)
        if self.rotation_augmentation:
            data.pos = self._random_rotation(data.pos)

        return data

    def _random_rotation(self, pos: torch.Tensor) -> torch.Tensor:
        """
        Apply a uniformly distributed random 3D rotation to ``pos`` [num_atoms, 3].

        A 4D standard-normal vector normalized to unit length is uniform on the
        3-sphere; interpreted as a unit quaternion it gives rotations uniform
        over SO(3). Sampling three Euler angles independently and uniformly
        would not.
        """
        q = torch.randn(4, dtype=torch.float64)
        w, x, y, z = (q / q.norm()).tolist()

        rotation = torch.tensor(
            [
                [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
                [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
                [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
            ],
            dtype=pos.dtype,
            device=pos.device,
        )
        # Row-vector positions: p' = p R^T.
        return pos @ rotation.T

    def denormalize(self, normalized_values: torch.Tensor) -> torch.Tensor:
        """Denormalize predictions back to original scale."""
        return normalized_values * self.target_std + self.target_mean


def create_qm9_dataloaders(
    root: str = "data/qm9",
    hf_dataset_path: str = "yairschiff/qm9",
    target_property: str = "homo",
    batch_size: int = 32,
    num_workers: int = 4,
    normalize_targets: bool = True,
    rotation_augmentation: bool = True,
    train_split: float = 0.8,
    val_split: float = 0.1,
) -> Dict[str, torch.utils.data.DataLoader]:
    """
    Create train/val/test dataloaders for QM9.

    The Hugging Face "train" split is loaded once and randomly partitioned
    (fixed seed 42, so the partition is reproducible). Target normalization
    statistics are computed on the training partition only and reused for
    validation and test; rotation augmentation is applied to training only.

    Args:
        root: Root directory for dataset
        hf_dataset_path: Hugging Face dataset path
        target_property: Target property to predict
        batch_size: Batch size
        num_workers: Number of data loading workers
        normalize_targets: Whether to normalize targets
        rotation_augmentation: Whether to use rotation augmentation (train only)
        train_split: Fraction for training (default: 0.8), must be > 0
        val_split: Fraction for validation (default: 0.1), must be >= 0;
            the remainder goes to test

    Returns:
        Dictionary with 'train', 'val', 'test' dataloaders, plus
        'train_dataset' (the wrapped training set, for denormalization)

    Raises:
        ValueError: if the split fractions are inconsistent.
    """
    from torch_geometric.loader import DataLoader
    from torch.utils.data import random_split

    # Small tolerance so e.g. 0.7 + 0.3 is not rejected due to float rounding.
    if train_split <= 0 or val_split < 0 or train_split + val_split > 1 + 1e-9:
        raise ValueError(
            f"Invalid split fractions: train_split={train_split}, val_split={val_split}; "
            "need train_split > 0, val_split >= 0 and train_split + val_split <= 1"
        )

    # Load the full dataset (only 'train' split exists in HF)
    print("Loading full QM9 dataset...")
    full_dataset = QM9Dataset(
        root=root,
        hf_dataset_path=hf_dataset_path,
        target_property=target_property,
        split="train",  # HF only has 'train' split
    )

    # Calculate split sizes
    total_size = len(full_dataset)
    train_size = int(train_split * total_size)
    val_size = int(val_split * total_size)
    test_size = total_size - train_size - val_size

    print(f"Splitting dataset: train={train_size}, val={val_size}, test={test_size}")

    # Random split with a fixed seed for a reproducible partition
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42),
    )

    # Wrap datasets - compute normalization from TRAINING set only!
    train_wrapped = MolecularDatasetWrapper(
        train_dataset,
        normalize_targets=normalize_targets,
        rotation_augmentation=rotation_augmentation,
    )

    # Val and test use TRAINING set's normalization stats and no augmentation
    shared_stats = {
        'target_mean': train_wrapped.target_mean if normalize_targets else None,
        'target_std': train_wrapped.target_std if normalize_targets else None,
    }
    val_wrapped = MolecularDatasetWrapper(
        val_dataset,
        normalize_targets=normalize_targets,
        rotation_augmentation=False,
        **shared_stats,
    )
    test_wrapped = MolecularDatasetWrapper(
        test_dataset,
        normalize_targets=normalize_targets,
        rotation_augmentation=False,
        **shared_stats,
    )

    # Pinned memory only helps (and is only supported) with a CUDA device.
    pin_memory = torch.cuda.is_available()

    def make_loader(dataset, shuffle: bool) -> DataLoader:
        """Build a PyG DataLoader with the shared batching/worker settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

    return {
        'train': make_loader(train_wrapped, shuffle=True),
        'val': make_loader(val_wrapped, shuffle=False),
        'test': make_loader(test_wrapped, shuffle=False),
        'train_dataset': train_wrapped,  # For denormalization
    }