Spaces:
Running
Running
| """PyTorch Dataset and DataLoader for structural mechanics data. | |
| Loads Parquet files, applies normalization, and prepares tensors for training. | |
| Stress and deflection targets are stored in log10 space because their values span many orders of magnitude (stress alone ranges from Pa to GPa). | |
| """ | |
| from pathlib import Path | |
| from typing import Optional | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from torch.utils.data import DataLoader, Dataset | |
| from src.data.schema import SafetyCategory | |
| from src.models.normalization import LogTransformStandardizer | |
# Columns fed to the model as numeric inputs. The sequence is significant:
# normalization relies on this exact ordering, so append rather than reorder.
# The group headings below reflect the apparent role of each column.
NUMERIC_FEATURES = [
    # geometry
    "length",
    "width",
    "height",
    "inner_radius",
    "outer_radius",
    "thickness",
    # material
    "elastic_modulus",
    "poisson_ratio",
    "yield_strength",
    "density",
    # loading
    "point_load",
    "distributed_load",
    "internal_pressure",
    "pressure",
    # cross-section properties
    "moment_of_inertia",
    "section_modulus",
    "cross_section_area",
]
# Maps each safety category's string value to the integer class index used as
# the classification target: SAFE -> 0, MARGINAL -> 1, FAILURE -> 2.
SAFETY_CLASS_MAP = {
    category.value: class_index
    for class_index, category in enumerate(
        (
            SafetyCategory.SAFE,
            SafetyCategory.MARGINAL,
            SafetyCategory.FAILURE,
        )
    )
}
class StructuralMechanicsDataset(Dataset):
    """Dataset for structural analysis surrogate training.

    Reads one Parquet split into memory, passes the numeric input columns
    through the supplied normalizer, and precomputes the log10 targets and
    the integer safety class so that ``__getitem__`` is plain indexing.

    Attributes set by ``__init__``:
        df: The DataFrame read from ``parquet_path``.
        X: The result of ``normalizer.transform``; indexed per sample.
        log_stress: float32 tensor, log10 of ``max_stress``.
        log_deflection: float32 tensor, log10 of ``max_deflection``.
        log_yield: float32 tensor, log10 of ``yield_strength``.
        safety_class: int64 tensor of class indices from ``SAFETY_CLASS_MAP``.
    """

    # Substituted for non-positive (or NaN) values before taking log10 so the
    # log-space tensors stay finite; log10(1e-30) == -30.
    _LOG_FLOOR = 1e-30

    def __init__(
        self,
        parquet_path: Path,
        normalizer: LogTransformStandardizer,
        fit_normalizer: bool = False,
    ) -> None:
        """Load a split and precompute model inputs and targets.

        Args:
            parquet_path: Parquet file to read. It must contain the columns
                ``config_id``, ``max_stress``, ``max_deflection``,
                ``yield_strength`` and ``safety_category``; columns listed in
                ``NUMERIC_FEATURES`` are used when present.
            normalizer: Object called as ``fit(features, config_ids)`` and
                ``transform(features, config_ids)``.
            fit_normalizer: When true, ``normalizer`` is fitted on this split
                before it is used to transform it.

        Raises:
            KeyError: If a required column is missing from the file.
            ValueError: If ``safety_category`` holds a value (including a
                missing value) that has no entry in ``SAFETY_CLASS_MAP``.
        """
        self.df = pd.read_parquet(parquet_path)

        # Only the numeric features actually present in this file are passed
        # on; dict insertion order follows NUMERIC_FEATURES.
        features = {
            col: self.df[col].values.astype(np.float64)
            for col in NUMERIC_FEATURES
            if col in self.df.columns
        }
        config_ids = self.df["config_id"].values

        if fit_normalizer:
            normalizer.fit(features, config_ids)

        # Transform inputs
        self.X = normalizer.transform(features, config_ids)

        # Targets in log-space. Yield strength goes through the same clamp as
        # stress and deflection so that a zero or negative entry cannot put
        # -inf/NaN into the tensor consumed by the physics consistency loss.
        self.log_stress = self._log10_tensor(self.df["max_stress"].values)
        self.log_deflection = self._log10_tensor(self.df["max_deflection"].values)
        self.log_yield = self._log10_tensor(self.df["yield_strength"].values)

        # Safety category as class index. Series.map yields NaN for labels
        # missing from the mapping, and casting NaN to int64 would silently
        # produce a meaningless class index, so fail loudly instead.
        labels = self.df["safety_category"]
        class_ids = labels.map(SAFETY_CLASS_MAP)
        unmapped = class_ids.isna()
        if unmapped.any():
            unknown = sorted({repr(value) for value in labels[unmapped]})
            raise ValueError(
                "safety_category contains values with no class index in "
                f"SAFETY_CLASS_MAP: {', '.join(unknown)}"
            )
        self.safety_class = torch.from_numpy(class_ids.values.astype(np.int64))

    @staticmethod
    def _log10_tensor(values: np.ndarray, floor: float = _LOG_FLOOR) -> torch.Tensor:
        """Return log10 of ``values`` as a float32 tensor.

        Entries that are not strictly positive (including NaN) are replaced
        by ``floor`` first, which avoids log(0) and logs of negative numbers.
        """
        clamped = np.where(values > 0, values, floor)
        return torch.from_numpy(np.log10(clamped).astype(np.float32))

    def __len__(self) -> int:
        """Return the number of rows in the loaded split."""
        return len(self.df)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Return ``(inputs, targets)`` for sample ``idx``.

        ``targets`` maps ``log_stress``, ``log_deflection``,
        ``log_yield_strength`` and ``safety_class`` to the corresponding
        entries of the precomputed tensors.
        """
        targets = {
            "log_stress": self.log_stress[idx],
            "log_deflection": self.log_deflection[idx],
            "log_yield_strength": self.log_yield[idx],
            "safety_class": self.safety_class[idx],
        }
        return self.X[idx], targets
def create_dataloaders(
    data_dir: Path,
    normalizer: LogTransformStandardizer,
    batch_size: int = 512,
    num_workers: int = 0,
) -> tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train, validation and test dataloaders.

    The datasets are read from ``train.parquet``, ``validation.parquet`` and
    ``test.parquet`` inside ``data_dir``. Only the training dataset is
    created with ``fit_normalizer=True``, so the normalizer never sees the
    held-out splits while fitting (no data leakage).

    Args:
        data_dir: Directory containing the three Parquet files.
        normalizer: Normalizer shared by all three datasets.
        batch_size: Batch size used by every loader.
        num_workers: Worker count passed to every loader.

    Returns:
        ``(train_loader, val_loader, test_loader)``. The training loader
        shuffles and drops the last incomplete batch; the other two do
        neither.
    """
    # The training split is constructed first so that the normalizer has been
    # fitted by the time the held-out splits are transformed.
    train_split = StructuralMechanicsDataset(
        data_dir / "train.parquet",
        normalizer,
        fit_normalizer=True,
    )
    held_out_splits = [
        StructuralMechanicsDataset(data_dir / filename, normalizer)
        for filename in ("validation.parquet", "test.parquet")
    ]

    # Settings common to all three loaders.
    shared_options = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": True,
    }

    train_loader = DataLoader(
        train_split, shuffle=True, drop_last=True, **shared_options
    )
    val_loader, test_loader = (
        DataLoader(split, shuffle=False, **shared_options)
        for split in held_out_splits
    )
    return train_loader, val_loader, test_loader