"""PyTorch Dataset and DataLoader for structural mechanics data. Loads Parquet files, applies normalization, and prepares tensors for training. Log-space targets for stress and deflection because values span Pa to GPa. """ from pathlib import Path from typing import Optional import numpy as np import pandas as pd import torch from torch.utils.data import DataLoader, Dataset from src.data.schema import SafetyCategory from src.models.normalization import LogTransformStandardizer # Numeric features used as model input (order matters for normalization) NUMERIC_FEATURES = [ "length", "width", "height", "inner_radius", "outer_radius", "thickness", "elastic_modulus", "poisson_ratio", "yield_strength", "density", "point_load", "distributed_load", "internal_pressure", "pressure", "moment_of_inertia", "section_modulus", "cross_section_area", ] SAFETY_CLASS_MAP = { SafetyCategory.SAFE.value: 0, SafetyCategory.MARGINAL.value: 1, SafetyCategory.FAILURE.value: 2, } class StructuralMechanicsDataset(Dataset): """Dataset for structural analysis surrogate training.""" def __init__( self, parquet_path: Path, normalizer: LogTransformStandardizer, fit_normalizer: bool = False, ) -> None: self.df = pd.read_parquet(parquet_path) # Extract features as dict of arrays features = {} for col in NUMERIC_FEATURES: if col in self.df.columns: features[col] = self.df[col].values.astype(np.float64) config_ids = self.df["config_id"].values if fit_normalizer: normalizer.fit(features, config_ids) # Transform inputs self.X = normalizer.transform(features, config_ids) # Targets in log-space stress = self.df["max_stress"].values deflection = self.df["max_deflection"].values # Clamp to avoid log(0) stress = np.where(stress > 0, stress, 1e-30) deflection = np.where(deflection > 0, deflection, 1e-30) self.log_stress = torch.from_numpy(np.log10(stress).astype(np.float32)) self.log_deflection = torch.from_numpy(np.log10(deflection).astype(np.float32)) # Log yield strength for physics loss consistency check yield_strength = self.df["yield_strength"].values self.log_yield = torch.from_numpy(np.log10(yield_strength).astype(np.float32)) # Safety category as class index safety_classes = self.df["safety_category"].map(SAFETY_CLASS_MAP).values self.safety_class = torch.from_numpy(safety_classes.astype(np.int64)) def __len__(self) -> int: return len(self.df) def __getitem__(self, idx: int) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: targets = { "log_stress": self.log_stress[idx], "log_deflection": self.log_deflection[idx], "log_yield_strength": self.log_yield[idx], "safety_class": self.safety_class[idx], } return self.X[idx], targets def create_dataloaders( data_dir: Path, normalizer: LogTransformStandardizer, batch_size: int = 512, num_workers: int = 0, ) -> tuple[DataLoader, DataLoader, DataLoader]: """Create train/val/test dataloaders. Fits the normalizer on the training set only (no data leakage). """ train_ds = StructuralMechanicsDataset( data_dir / "train.parquet", normalizer, fit_normalizer=True, ) val_ds = StructuralMechanicsDataset( data_dir / "validation.parquet", normalizer, ) test_ds = StructuralMechanicsDataset( data_dir / "test.parquet", normalizer, ) train_loader = DataLoader( train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True, ) val_loader = DataLoader( val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, ) test_loader = DataLoader( test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, ) return train_loader, val_loader, test_loader