Spaces:
Running
Running
| """ | |
| src/dataset.py | |
| -------------- | |
| Galaxy Zoo 2 dataset loader for hierarchical probabilistic regression. | |
| The GZ2 decision tree has 11 questions (t01-t11) with 37 total answer | |
columns. Each question is a conditional probability vector — not
| independent regression targets. | |
| Hierarchy (parent answer -> child question): | |
| t01_a02 (features/disk) -> t02, t03, t04, t05, t06 | |
| t02_a05 (not edge-on) -> t03, t04 | |
| t04_a08 (has spiral) -> t10, t11 | |
| t06_a14 (odd feature) -> t08 | |
| t01_a01 (smooth) -> t07 | |
| t02_a04 (edge-on) -> t09 | |
| References | |
| ---------- | |
| Willett et al. (2013), MNRAS 435, 2835 | |
| Hart et al. (2016), MNRAS 461, 3663 | |
| """ | |
| import math | |
| import logging | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
| from PIL import Image | |
| from omegaconf import DictConfig | |
| log = logging.getLogger(__name__) | |
# ─────────────────────────────────────────────────────────────
# GZ2 decision tree definition
# ─────────────────────────────────────────────────────────────
# Debiased vote-fraction column names, ordered so that each question's
# answers are contiguous; QUESTION_GROUPS below slices into this list.
LABEL_COLUMNS = [
    # t01: smooth or features?
    "t01_smooth_or_features_a01_smooth_debiased",
    "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t01_smooth_or_features_a03_star_or_artifact_debiased",
    # t02: edge-on?
    "t02_edgeon_a04_yes_debiased",
    "t02_edgeon_a05_no_debiased",
    # t03: bar?
    "t03_bar_a06_bar_debiased",
    "t03_bar_a07_no_bar_debiased",
    # t04: spiral?
    "t04_spiral_a08_spiral_debiased",
    "t04_spiral_a09_no_spiral_debiased",
    # t05: bulge prominence
    "t05_bulge_prominence_a10_no_bulge_debiased",
    "t05_bulge_prominence_a11_just_noticeable_debiased",
    "t05_bulge_prominence_a12_obvious_debiased",
    "t05_bulge_prominence_a13_dominant_debiased",
    # t06: odd feature?
    "t06_odd_a14_yes_debiased",
    "t06_odd_a15_no_debiased",
    # t07: roundedness (smooth galaxies)
    "t07_rounded_a16_completely_round_debiased",
    "t07_rounded_a17_in_between_debiased",
    "t07_rounded_a18_cigar_shaped_debiased",
    # t08: odd feature type
    "t08_odd_feature_a19_ring_debiased",
    "t08_odd_feature_a20_lens_or_arc_debiased",
    "t08_odd_feature_a21_disturbed_debiased",
    "t08_odd_feature_a22_irregular_debiased",
    "t08_odd_feature_a23_other_debiased",
    "t08_odd_feature_a24_merger_debiased",
    "t08_odd_feature_a38_dust_lane_debiased",
    # t09: bulge shape (edge-on only)
    "t09_bulge_shape_a25_rounded_debiased",
    "t09_bulge_shape_a26_boxy_debiased",
    "t09_bulge_shape_a27_no_bulge_debiased",
    # t10: arms winding
    "t10_arms_winding_a28_tight_debiased",
    "t10_arms_winding_a29_medium_debiased",
    "t10_arms_winding_a30_loose_debiased",
    # t11: arms number
    "t11_arms_number_a31_1_debiased",
    "t11_arms_number_a32_2_debiased",
    "t11_arms_number_a33_3_debiased",
    "t11_arms_number_a34_4_debiased",
    "t11_arms_number_a36_more_than_4_debiased",
    "t11_arms_number_a37_cant_tell_debiased",
]

# Slice indices into LABEL_COLUMNS for each question group.
# Each value is a half-open (start, stop) pair; slice widths match the
# per-question answer counts above (3,2,2,2,4,2,3,7,3,3,6 = 37 total).
QUESTION_GROUPS = {
    "t01": (0, 3),
    "t02": (3, 5),
    "t03": (5, 7),
    "t04": (7, 9),
    "t05": (9, 13),
    "t06": (13, 15),
    "t07": (15, 18),
    "t08": (18, 25),
    "t09": (25, 28),
    "t10": (28, 31),
    "t11": (31, 37),
}

# Parent answer column for hierarchical branch weighting.
# w_q = vote fraction of the parent answer that unlocks question q.
# t01 is the root question; its weight is always 1.0.
# NOTE(review): this maps each question to a single immediate parent answer.
# In the Willett et al. (2013) tree, both t04 answers lead to t05, so t05's
# unlocking fraction is arguably t02_a05 ("not edge-on") rather than t01_a02;
# likewise deep questions (t10/t11) would need the cumulative product of
# fractions along the path, not just the last parent — confirm the intended
# weighting scheme before changing.
QUESTION_PARENT_COL = {
    "t01": None,
    "t02": "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t03": "t02_edgeon_a05_no_debiased",
    "t04": "t02_edgeon_a05_no_debiased",
    "t05": "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t06": "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t07": "t01_smooth_or_features_a01_smooth_debiased",
    "t08": "t06_odd_a14_yes_debiased",
    "t09": "t02_edgeon_a04_yes_debiased",
    "t10": "t04_spiral_a08_spiral_debiased",
    "t11": "t04_spiral_a08_spiral_debiased",
}

N_LABELS = len(LABEL_COLUMNS)  # 37
# ─────────────────────────────────────────────────────────────
# Image transforms
# ─────────────────────────────────────────────────────────────
def get_transforms(image_size: int, split: str) -> transforms.Compose:
    """Build the torchvision pipeline for one dataset split.

    Train: oversized resize + random crop, random flips and a full
    ±180° rotation (galaxies have no preferred orientation), mild
    colour jitter (instrument variation), ImageNet normalisation.
    Any other split: deterministic resize + ImageNet normalisation.
    """
    normalise = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    if split != "train":
        # Evaluation path: no augmentation, just a fixed-size tensor.
        return transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            normalise,
        ])

    return transforms.Compose([
        # Resize slightly larger so RandomCrop adds translation jitter.
        transforms.Resize((image_size + 16, image_size + 16)),
        transforms.RandomCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(180),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
        transforms.ToTensor(),
        normalise,
    ])
# ─────────────────────────────────────────────────────────────
# Dataset
# ─────────────────────────────────────────────────────────────
class GalaxyZoo2Dataset(Dataset):
    """
    PyTorch Dataset for Galaxy Zoo 2.

    Returns
    -------
    image : FloatTensor [3, H, W] normalised galaxy image
    targets : FloatTensor [37] vote fraction vector
    weights : FloatTensor [11] per-question hierarchical weights
    image_id : int dr7objid for traceability
    """

    def __init__(self, df: pd.DataFrame, image_dir: str, transform):
        self.df = df.reset_index(drop=True)
        self.image_dir = Path(image_dir)
        self.transform = transform
        # Pre-extract labels/weights as float32 arrays once, up front,
        # so __getitem__ avoids per-sample DataFrame access.
        self.labels = self.df[LABEL_COLUMNS].values.astype(np.float32)
        self.weights = self._compute_weights()
        self.image_ids = self.df["dr7objid"].tolist()

    def _compute_weights(self) -> np.ndarray:
        """One column per question: 1.0 for the root question (no parent),
        otherwise the debiased vote fraction of the unlocking parent answer."""
        n_rows = len(self.df)
        columns = []
        for question in QUESTION_GROUPS:
            parent = QUESTION_PARENT_COL[question]
            if parent is None:
                columns.append(np.ones(n_rows, dtype=np.float32))
            else:
                columns.append(self.df[parent].to_numpy(dtype=np.float32))
        return np.stack(columns, axis=1)

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        image_id = self.image_ids[idx]
        img_path = self.image_dir / f"{image_id}.jpg"
        try:
            image = Image.open(img_path).convert("RGB")
        except FileNotFoundError:
            raise FileNotFoundError(
                f"Image not found: {img_path}. "
                f"Check dr7objid {image_id} has a matching .jpg file."
            )
        targets = torch.from_numpy(self.labels[idx])
        weights = torch.from_numpy(self.weights[idx])
        return self.transform(image), targets, weights, image_id
# ─────────────────────────────────────────────────────────────
# DataLoader factory
# ─────────────────────────────────────────────────────────────
def build_dataloaders(cfg: DictConfig):
    """Build train / val / test DataLoaders from the labels parquet.

    Reads cfg.data.{parquet_path, image_dir, image_size, n_samples,
    train_frac, val_frac, num_workers, pin_memory} plus
    cfg.training.batch_size and cfg.seed.

    Returns
    -------
    (train_loader, val_loader, test_loader) : tuple of DataLoader

    Raises
    ------
    ValueError
        If any expected label column is missing from the parquet.
    """
    log.info("Loading parquet: %s", cfg.data.parquet_path)
    df = pd.read_parquet(cfg.data.parquet_path)

    missing = [c for c in LABEL_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"Missing columns in parquet: {missing}")

    if cfg.data.n_samples is not None:
        n_sub = int(cfg.data.n_samples)
        log.info("Using subset of %d samples (full dataset: %d)", n_sub, len(df))
        df = df.sample(n=n_sub, random_state=cfg.seed).reset_index(drop=True)
    else:
        log.info("Using full dataset: %d samples", len(df))

    # Deterministic shuffled split: train_frac / val_frac / remainder-to-test.
    rng = np.random.default_rng(cfg.seed)
    idx = rng.permutation(len(df))
    n = len(df)
    n_train = math.floor(cfg.data.train_frac * n)
    n_val = math.floor(cfg.data.val_frac * n)
    train_idx = idx[:n_train]
    val_idx = idx[n_train : n_train + n_val]
    test_idx = idx[n_train + n_val :]
    log.info("Split - train: %d val: %d test: %d",
             len(train_idx), len(val_idx), len(test_idx))

    image_size = cfg.data.image_size
    train_ds = GalaxyZoo2Dataset(
        df.iloc[train_idx], cfg.data.image_dir,
        get_transforms(image_size, "train"))
    val_ds = GalaxyZoo2Dataset(
        df.iloc[val_idx], cfg.data.image_dir,
        get_transforms(image_size, "val"))
    test_ds = GalaxyZoo2Dataset(
        df.iloc[test_idx], cfg.data.image_dir,
        get_transforms(image_size, "test"))

    num_workers = cfg.data.num_workers
    common = dict(
        batch_size=cfg.training.batch_size,
        num_workers=num_workers,
        pin_memory=cfg.data.pin_memory,
        drop_last=False,
    )
    # BUGFIX: persistent_workers / prefetch_factor are only valid when
    # num_workers > 0; passing them with num_workers == 0 makes DataLoader
    # raise ValueError. Only forward them for multi-worker loading.
    if num_workers > 0:
        common["persistent_workers"] = getattr(cfg.data, "persistent_workers", True)
        common["prefetch_factor"] = getattr(cfg.data, "prefetch_factor", 4)

    train_loader = DataLoader(train_ds, shuffle=True, **common)
    val_loader = DataLoader(val_ds, shuffle=False, **common)
    test_loader = DataLoader(test_ds, shuffle=False, **common)
    return train_loader, val_loader, test_loader