""" src/dataset.py -------------- Galaxy Zoo 2 dataset loader for hierarchical probabilistic regression. The GZ2 decision tree has 11 questions (t01-t11) with 37 total answer columns. Each question is a conditional probability vector — not independent regression targets. Hierarchy (parent answer -> child question): t01_a02 (features/disk) -> t02, t03, t04, t05, t06 t02_a05 (not edge-on) -> t03, t04 t04_a08 (has spiral) -> t10, t11 t06_a14 (odd feature) -> t08 t01_a01 (smooth) -> t07 t02_a04 (edge-on) -> t09 References ---------- Willett et al. (2013), MNRAS 435, 2835 Hart et al. (2016), MNRAS 461, 3663 """ import math import logging from pathlib import Path import numpy as np import pandas as pd import torch from torch.utils.data import Dataset, DataLoader from torchvision import transforms from PIL import Image from omegaconf import DictConfig log = logging.getLogger(__name__) # ───────────────────────────────────────────────────────────── # GZ2 decision tree definition # ───────────────────────────────────────────────────────────── LABEL_COLUMNS = [ # t01: smooth or features? "t01_smooth_or_features_a01_smooth_debiased", "t01_smooth_or_features_a02_features_or_disk_debiased", "t01_smooth_or_features_a03_star_or_artifact_debiased", # t02: edge-on? "t02_edgeon_a04_yes_debiased", "t02_edgeon_a05_no_debiased", # t03: bar? "t03_bar_a06_bar_debiased", "t03_bar_a07_no_bar_debiased", # t04: spiral? "t04_spiral_a08_spiral_debiased", "t04_spiral_a09_no_spiral_debiased", # t05: bulge prominence "t05_bulge_prominence_a10_no_bulge_debiased", "t05_bulge_prominence_a11_just_noticeable_debiased", "t05_bulge_prominence_a12_obvious_debiased", "t05_bulge_prominence_a13_dominant_debiased", # t06: odd feature? 
"t06_odd_a14_yes_debiased", "t06_odd_a15_no_debiased", # t07: roundedness (smooth galaxies) "t07_rounded_a16_completely_round_debiased", "t07_rounded_a17_in_between_debiased", "t07_rounded_a18_cigar_shaped_debiased", # t08: odd feature type "t08_odd_feature_a19_ring_debiased", "t08_odd_feature_a20_lens_or_arc_debiased", "t08_odd_feature_a21_disturbed_debiased", "t08_odd_feature_a22_irregular_debiased", "t08_odd_feature_a23_other_debiased", "t08_odd_feature_a24_merger_debiased", "t08_odd_feature_a38_dust_lane_debiased", # t09: bulge shape (edge-on only) "t09_bulge_shape_a25_rounded_debiased", "t09_bulge_shape_a26_boxy_debiased", "t09_bulge_shape_a27_no_bulge_debiased", # t10: arms winding "t10_arms_winding_a28_tight_debiased", "t10_arms_winding_a29_medium_debiased", "t10_arms_winding_a30_loose_debiased", # t11: arms number "t11_arms_number_a31_1_debiased", "t11_arms_number_a32_2_debiased", "t11_arms_number_a33_3_debiased", "t11_arms_number_a34_4_debiased", "t11_arms_number_a36_more_than_4_debiased", "t11_arms_number_a37_cant_tell_debiased", ] # Slice indices into LABEL_COLUMNS for each question group. QUESTION_GROUPS = { "t01": (0, 3), "t02": (3, 5), "t03": (5, 7), "t04": (7, 9), "t05": (9, 13), "t06": (13, 15), "t07": (15, 18), "t08": (18, 25), "t09": (25, 28), "t10": (28, 31), "t11": (31, 37), } # Parent answer column for hierarchical branch weighting. # w_q = vote fraction of the parent answer that unlocks question q. # t01 is the root question; its weight is always 1.0. 
QUESTION_PARENT_COL = {
    "t01": None,  # root question — weight fixed at 1.0 in _compute_weights
    "t02": "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t03": "t02_edgeon_a05_no_debiased",
    "t04": "t02_edgeon_a05_no_debiased",
    "t05": "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t06": "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t07": "t01_smooth_or_features_a01_smooth_debiased",
    "t08": "t06_odd_a14_yes_debiased",
    "t09": "t02_edgeon_a04_yes_debiased",
    "t10": "t04_spiral_a08_spiral_debiased",
    "t11": "t04_spiral_a08_spiral_debiased",
}

N_LABELS = len(LABEL_COLUMNS)  # 37


# ─────────────────────────────────────────────────────────────
# Image transforms
# ─────────────────────────────────────────────────────────────
def get_transforms(image_size: int, split: str) -> "transforms.Compose":
    """
    Build the torchvision transform pipeline for one split.

    Training: random flips + rotations (galaxies have no preferred
    orientation), colour jitter (instrument variation), ImageNet
    normalisation. Val/Test: resize only, ImageNet normalisation.

    Parameters
    ----------
    image_size : int
        Side length of the square output image.
    split : str
        "train" enables augmentation; any other value gets the
        deterministic eval pipeline.
    """
    # ImageNet statistics — required when fine-tuning ImageNet-pretrained
    # backbones.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    if split == "train":
        return transforms.Compose([
            # Resize slightly larger, then random-crop back down so the
            # galaxy centre shifts a little between epochs.
            transforms.Resize((image_size + 16, image_size + 16)),
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(180),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
    else:
        return transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])


# ─────────────────────────────────────────────────────────────
# Dataset
# ─────────────────────────────────────────────────────────────
class GalaxyZoo2Dataset(Dataset):
    """
    PyTorch Dataset for Galaxy Zoo 2.

    Returns
    -------
    image : FloatTensor [3, H, W]
        normalised galaxy image
    targets : FloatTensor [37]
        vote fraction vector
    weights : FloatTensor [11]
        per-question hierarchical weights
    image_id : int
        dr7objid for traceability
    """

    def __init__(self, df: pd.DataFrame, image_dir: str, transform):
        """
        Parameters
        ----------
        df : pd.DataFrame
            Must contain all LABEL_COLUMNS plus a "dr7objid" column.
        image_dir : str
            Directory containing one "<dr7objid>.jpg" per row.
        transform : callable
            PIL-image -> tensor transform (see get_transforms).
        """
        self.df = df.reset_index(drop=True)
        self.image_dir = Path(image_dir)
        self.transform = transform
        # Materialise labels/weights once up front so __getitem__ does no
        # pandas work in the worker processes.
        self.labels = self.df[LABEL_COLUMNS].values.astype(np.float32)
        self.weights = self._compute_weights()
        self.image_ids = self.df["dr7objid"].tolist()

    def _compute_weights(self) -> np.ndarray:
        """
        Compute the [n_samples, 11] hierarchical weight matrix.

        Each question's weight is the vote fraction of the parent answer
        that unlocks it (QUESTION_PARENT_COL); the root question t01 keeps
        the initialised weight of 1.0.
        """
        n = len(self.df)
        q_names = list(QUESTION_GROUPS.keys())
        weights = np.ones((n, len(q_names)), dtype=np.float32)
        for q_idx, q_name in enumerate(q_names):
            parent_col = QUESTION_PARENT_COL[q_name]
            if parent_col is not None:
                weights[:, q_idx] = self.df[parent_col].values.astype(np.float32)
        return weights

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        """Load, transform and return one sample (see class docstring)."""
        image_id = self.image_ids[idx]
        img_path = self.image_dir / f"{image_id}.jpg"
        try:
            image = Image.open(img_path).convert("RGB")
        except FileNotFoundError as e:
            # Chain the original exception so the real failing path/syscall
            # stays visible in the traceback.
            raise FileNotFoundError(
                f"Image not found: {img_path}. "
                f"Check dr7objid {image_id} has a matching .jpg file."
            ) from e
        image = self.transform(image)
        targets = torch.from_numpy(self.labels[idx])
        weights = torch.from_numpy(self.weights[idx])
        return image, targets, weights, image_id


# ─────────────────────────────────────────────────────────────
# DataLoader factory
# ─────────────────────────────────────────────────────────────
def build_dataloaders(cfg: "DictConfig"):
    """
    Build train / val / test DataLoaders from the labels parquet.

    Expects cfg.data.{parquet_path, image_dir, image_size, n_samples,
    train_frac, val_frac, num_workers, pin_memory}, cfg.training.batch_size
    and cfg.seed.

    Raises
    ------
    ValueError
        If label columns are missing from the parquet, or if
        train_frac + val_frac exceeds 1 (no room for a test split).
    """
    log.info("Loading parquet: %s", cfg.data.parquet_path)
    df = pd.read_parquet(cfg.data.parquet_path)

    missing = [c for c in LABEL_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"Missing columns in parquet: {missing}")

    if cfg.data.train_frac + cfg.data.val_frac > 1.0:
        raise ValueError(
            f"train_frac ({cfg.data.train_frac}) + val_frac "
            f"({cfg.data.val_frac}) must not exceed 1.0"
        )

    if cfg.data.n_samples is not None:
        n = int(cfg.data.n_samples)
        log.info("Using subset of %d samples (full dataset: %d)", n, len(df))
        df = df.sample(n=n, random_state=cfg.seed).reset_index(drop=True)
    else:
        log.info("Using full dataset: %d samples", len(df))

    # Deterministic shuffled split: train | val | test (test = remainder).
    rng = np.random.default_rng(cfg.seed)
    idx = rng.permutation(len(df))
    n = len(df)
    n_train = math.floor(cfg.data.train_frac * n)
    n_val = math.floor(cfg.data.val_frac * n)
    train_idx = idx[:n_train]
    val_idx = idx[n_train : n_train + n_val]
    test_idx = idx[n_train + n_val :]
    log.info("Split — train: %d val: %d test: %d",
             len(train_idx), len(val_idx), len(test_idx))

    image_size = cfg.data.image_size
    train_ds = GalaxyZoo2Dataset(
        df.iloc[train_idx], cfg.data.image_dir, get_transforms(image_size, "train"))
    val_ds = GalaxyZoo2Dataset(
        df.iloc[val_idx], cfg.data.image_dir, get_transforms(image_size, "val"))
    test_ds = GalaxyZoo2Dataset(
        df.iloc[test_idx], cfg.data.image_dir, get_transforms(image_size, "test"))

    num_workers = cfg.data.num_workers
    common = dict(
        batch_size=cfg.training.batch_size,
        num_workers=num_workers,
        pin_memory=cfg.data.pin_memory,
        drop_last=False,
    )
    # persistent_workers and prefetch_factor are only legal with worker
    # processes; DataLoader raises ValueError for either when
    # num_workers == 0, so add them conditionally.
    if num_workers > 0:
        common["persistent_workers"] = getattr(cfg.data, "persistent_workers", True)
        common["prefetch_factor"] = getattr(cfg.data, "prefetch_factor", 4)

    train_loader = DataLoader(train_ds, shuffle=True, **common)
    val_loader = DataLoader(val_ds, shuffle=False, **common)
    test_loader = DataLoader(test_ds, shuffle=False, **common)
    return train_loader, val_loader, test_loader