| """ | |
| Dataset Loaders for Molecular Property Prediction | |
| Supports loading datasets from Hugging Face and PyTorch Geometric. | |
| """ | |
| import torch | |
| import numpy as np | |
| from torch.utils.data import Dataset | |
| from torch_geometric.data import Data, InMemoryDataset | |
| from datasets import load_dataset | |
| from typing import List, Tuple, Optional, Dict | |
| import os | |


class QM9Dataset(InMemoryDataset):
    """
    QM9 dataset loader from Hugging Face.

    Loads the QM9 dataset (~133k molecules) with 3D coordinates and quantum
    properties.

    Args:
        root (str): Root directory for dataset storage
        hf_dataset_path (str): Hugging Face dataset path (default: "yairschiff/qm9")
        target_property (str): Target property to predict (default: "homo")
        split (str): Dataset split ("train", "val", "test")
        transform: Optional transform to apply
        pre_transform: Optional pre-transform to apply
    """

    def __init__(
        self,
        root: str = "data/qm9",
        hf_dataset_path: str = "yairschiff/qm9",
        target_property: str = "homo",
        split: str = "train",
        transform=None,
        pre_transform=None,
    ):
        self.hf_dataset_path = hf_dataset_path
        self.target_property = target_property
        self.split = split
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

    @property
    def raw_file_names(self) -> List[str]:
        return []  # Data is loaded directly from Hugging Face

    @property
    def processed_file_names(self) -> List[str]:
        return [f'qm9_{self.split}_processed.pt']

    def download(self):
        """Download is handled by the Hugging Face datasets library."""
        pass

    def process(self):
        """Process raw data into PyTorch Geometric format."""
        print(f"Loading QM9 dataset from Hugging Face: {self.hf_dataset_path}")
        # Load the dataset directly from Hugging Face
        raw_dataset = load_dataset(self.hf_dataset_path, split=self.split)
        print(f"Loaded {len(raw_dataset)} molecules for {self.split} split")

        data_list = []
        print("Processing molecules...")
        for idx, sample in enumerate(raw_dataset):
            if idx % 10000 == 0:
                print(f"  Processed {idx}/{len(raw_dataset)} molecules")

            # Extract features, using the column names of the HF dataset
            atomic_symbols = sample['atomic_symbols']  # e.g. ['C', 'H', 'H', ...]
            positions = torch.tensor(sample['pos'], dtype=torch.float32)  # 3D coordinates

            # Convert atomic symbols to atomic numbers
            symbol_to_number = {
                'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5, 'C': 6,
                'N': 7, 'O': 8, 'F': 9, 'Ne': 10,
            }
            atomic_numbers = [symbol_to_number.get(sym, 0) for sym in atomic_symbols]

            # One-hot encode atomic numbers (H=1 to Ne=10, plus one slot for unknown)
            num_atoms = len(atomic_numbers)
            node_features = torch.zeros(num_atoms, 11)  # 11 possible atom types
            for i, z in enumerate(atomic_numbers):
                if 1 <= z <= 10:
                    node_features[i, z - 1] = 1.0
                else:
                    node_features[i, 10] = 1.0  # Unknown atom type

            # Get the target property
            target = torch.tensor([sample[self.target_property]], dtype=torch.float32)

            # Distance-based edges (5.0 Angstrom cutoff)
            edge_index = self._get_edge_index(num_atoms, cutoff=5.0, positions=positions)

            # Create the PyG Data object
            data = Data(
                x=node_features,
                pos=positions,
                edge_index=edge_index,
                y=target,
                num_atoms=num_atoms,
            )

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        # Save the processed data
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        print(f"Processed {len(data_list)} molecules")

    def _get_edge_index(self, num_atoms: int, cutoff: float = 5.0, positions: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Create an edge index based on a distance cutoff.

        Args:
            num_atoms: Number of atoms in the molecule
            cutoff: Distance cutoff in Angstroms
            positions: Atom positions [num_atoms, 3] (optional)

        Returns:
            edge_index: Edge indices [2, num_edges]
        """
        if positions is not None:
            # Distance-based edges: connect all atom pairs within the cutoff
            dist_matrix = torch.cdist(positions, positions)
            within_cutoff = (dist_matrix < cutoff) & (dist_matrix > 0)  # Exclude self-loops
            edge_index = within_cutoff.nonzero().t().contiguous()
        else:
            # Fallback: fully connected graph (for backward compatibility)
            row, col = [], []
            for i in range(num_atoms):
                for j in range(num_atoms):
                    if i != j:
                        row.append(i)
                        col.append(j)
            edge_index = torch.tensor([row, col], dtype=torch.long)
        return edge_index
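

# A minimal sanity check for the distance-cutoff edge construction above (a
# sketch, not part of the public API; the helper name is ours): three
# collinear atoms 1 A apart with a 1.5 A cutoff should connect only adjacent
# pairs, in both directions.
def _edge_index_smoke_test() -> None:
    pos = torch.tensor([
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [2.0, 0.0, 0.0],
    ])
    dist = torch.cdist(pos, pos)
    within = (dist < 1.5) & (dist > 0)  # Exclude self-loops, as in QM9Dataset
    edge_index = within.nonzero().t()
    # Expect the directed edges (0,1), (1,0), (1,2), (2,1): shape [2, 4]
    assert edge_index.shape == (2, 4)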


class MolecularDatasetWrapper(Dataset):
    """
    Wrapper for molecular datasets with preprocessing and normalization.

    Args:
        dataset: Base dataset (e.g., QM9Dataset or a Subset of one)
        normalize_targets: Whether to normalize target values
        rotation_augmentation: Whether to apply random rotations
        target_mean: Optional precomputed target mean (e.g., from the training set)
        target_std: Optional precomputed target std (e.g., from the training set)
    """

    def __init__(
        self,
        dataset: Dataset,
        normalize_targets: bool = True,
        rotation_augmentation: bool = False,
        target_mean: Optional[float] = None,
        target_std: Optional[float] = None,
    ):
        self.dataset = dataset
        self.normalize_targets = normalize_targets
        self.rotation_augmentation = rotation_augmentation

        # Use the provided stats or compute them from this dataset
        if normalize_targets:
            if target_mean is not None and target_std is not None:
                # Use provided normalization stats (for val/test)
                self.target_mean = target_mean
                self.target_std = target_std
            else:
                # Compute from this dataset (for train)
                self.target_mean, self.target_std = self._compute_normalization_stats()
        else:
            self.target_mean = 0.0
            self.target_std = 1.0

    def _compute_normalization_stats(self) -> Tuple[float, float]:
        """Compute the mean and std of the target values."""
        targets = np.array([self.dataset[i].y.item() for i in range(len(self.dataset))])
        return float(targets.mean()), float(targets.std())

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Data:
        data = self.dataset[idx].clone()

        # Normalize the target
        if self.normalize_targets:
            data.y = (data.y - self.target_mean) / self.target_std

        # Apply rotation augmentation (only if enabled)
        if self.rotation_augmentation:
            data.pos = self._random_rotation(data.pos)

        return data

    def _random_rotation(self, pos: torch.Tensor) -> torch.Tensor:
        """Apply a random 3D rotation."""
        # Sample Euler angles uniformly in [0, 2*pi)
        ax, ay, az = (torch.rand(3) * 2 * np.pi).tolist()

        # Rotation around the x-axis
        Rx = torch.tensor([
            [1.0, 0.0, 0.0],
            [0.0, np.cos(ax), -np.sin(ax)],
            [0.0, np.sin(ax), np.cos(ax)],
        ], dtype=pos.dtype)
        # Rotation around the y-axis
        Ry = torch.tensor([
            [np.cos(ay), 0.0, np.sin(ay)],
            [0.0, 1.0, 0.0],
            [-np.sin(ay), 0.0, np.cos(ay)],
        ], dtype=pos.dtype)
        # Rotation around the z-axis
        Rz = torch.tensor([
            [np.cos(az), -np.sin(az), 0.0],
            [np.sin(az), np.cos(az), 0.0],
            [0.0, 0.0, 1.0],
        ], dtype=pos.dtype)

        # Combined rotation: x, then y, then z
        R = Rz @ Ry @ Rx
        return pos @ R.T

    def denormalize(self, normalized_values: torch.Tensor) -> torch.Tensor:
        """Denormalize predictions back to the original scale."""
        return normalized_values * self.target_std + self.target_mean
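

# Worked example of the normalization round-trip implemented above
# (hypothetical numbers): with training-set stats mean=0.2 and std=0.05, a raw
# target y=0.25 is normalized to (0.25 - 0.2) / 0.05 = 1.0, and
# denormalize(1.0) = 1.0 * 0.05 + 0.2 recovers 0.25. The val/test wrappers
# must reuse the training stats for this inversion to hold.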


def create_qm9_dataloaders(
    root: str = "data/qm9",
    hf_dataset_path: str = "yairschiff/qm9",
    target_property: str = "homo",
    batch_size: int = 32,
    num_workers: int = 4,
    normalize_targets: bool = True,
    rotation_augmentation: bool = True,
    train_split: float = 0.8,
    val_split: float = 0.1,
) -> Dict[str, object]:
    """
    Create train/val/test dataloaders for QM9.

    Args:
        root: Root directory for the dataset
        hf_dataset_path: Hugging Face dataset path
        target_property: Target property to predict
        batch_size: Batch size
        num_workers: Number of data loading workers
        normalize_targets: Whether to normalize targets
        rotation_augmentation: Whether to use rotation augmentation (train only)
        train_split: Fraction of data used for training (default: 0.8)
        val_split: Fraction of data used for validation (default: 0.1)

    Returns:
        Dictionary with 'train', 'val', and 'test' dataloaders, plus the
        wrapped training set under 'train_dataset' (for denormalization).
    """
    from torch_geometric.loader import DataLoader
    from torch.utils.data import random_split

    # Load the full dataset (only a 'train' split exists on Hugging Face)
    print("Loading full QM9 dataset...")
    full_dataset = QM9Dataset(
        root=root,
        hf_dataset_path=hf_dataset_path,
        target_property=target_property,
        split="train",  # The HF dataset only provides a 'train' split
    )

    # Calculate the split sizes
    total_size = len(full_dataset)
    train_size = int(train_split * total_size)
    val_size = int(val_split * total_size)
    test_size = total_size - train_size - val_size
    print(f"Splitting dataset: train={train_size}, val={val_size}, test={test_size}")

    # Random split with a fixed seed for reproducibility
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42),
    )

    # Wrap the datasets; compute normalization stats from the TRAINING set only,
    # to avoid leaking val/test statistics
    train_wrapped = MolecularDatasetWrapper(
        train_dataset,
        normalize_targets=normalize_targets,
        rotation_augmentation=rotation_augmentation,
    )
    # Val and test reuse the TRAINING set's normalization stats
    val_wrapped = MolecularDatasetWrapper(
        val_dataset,
        normalize_targets=normalize_targets,
        rotation_augmentation=False,  # No augmentation for validation
        target_mean=train_wrapped.target_mean if normalize_targets else None,
        target_std=train_wrapped.target_std if normalize_targets else None,
    )
    test_wrapped = MolecularDatasetWrapper(
        test_dataset,
        normalize_targets=normalize_targets,
        rotation_augmentation=False,  # No augmentation for test
        target_mean=train_wrapped.target_mean if normalize_targets else None,
        target_std=train_wrapped.target_std if normalize_targets else None,
    )

    # Create the dataloaders
    train_loader = DataLoader(
        train_wrapped,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_wrapped,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    test_loader = DataLoader(
        test_wrapped,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader,
        'train_dataset': train_wrapped,  # Exposed for denormalizing predictions
    }
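

if __name__ == "__main__":
    # Minimal usage sketch (assumes torch_geometric is installed and the
    # Hugging Face dataset is reachable); all names used here come from this
    # module.
    loaders = create_qm9_dataloaders(batch_size=8, num_workers=0)
    batch = next(iter(loaders['train']))
    print(batch)  # A PyG Batch, e.g. x=[N, 11], pos=[N, 3], edge_index=[2, E], y=[8]
    # Map normalized targets (or model predictions) back to the original scale
    print(loaders['train_dataset'].denormalize(batch.y))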