# eshwar-gz2-api / src/dataset.py
# Uploaded via huggingface_hub (commit e36eee4, verified) by sreshwarprasad.
"""
src/dataset.py
--------------
Galaxy Zoo 2 dataset loader for hierarchical probabilistic regression.
The GZ2 decision tree has 11 questions (t01-t11) with 37 total answer
columns. Each question is a conditional probability vector — not
independent regression targets.
Hierarchy (parent answer -> child question):
t01_a02 (features/disk) -> t02, t03, t04, t05, t06
t02_a05 (not edge-on) -> t03, t04
t04_a08 (has spiral) -> t10, t11
t06_a14 (odd feature) -> t08
t01_a01 (smooth) -> t07
t02_a04 (edge-on) -> t09
References
----------
Willett et al. (2013), MNRAS 435, 2835
Hart et al. (2016), MNRAS 461, 3663
"""
import math
import logging
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from omegaconf import DictConfig
log = logging.getLogger(__name__)
# ─────────────────────────────────────────────────────────────
# GZ2 decision tree definition
# ─────────────────────────────────────────────────────────────
# Debiased vote-fraction columns, grouped by decision-tree question.
# Keeping each question's answers together lets the flat column list
# and the slice table below be derived instead of hand-maintained.
_QUESTION_ANSWERS = {
    # t01: smooth or features?
    "t01": [
        "t01_smooth_or_features_a01_smooth_debiased",
        "t01_smooth_or_features_a02_features_or_disk_debiased",
        "t01_smooth_or_features_a03_star_or_artifact_debiased",
    ],
    # t02: edge-on?
    "t02": [
        "t02_edgeon_a04_yes_debiased",
        "t02_edgeon_a05_no_debiased",
    ],
    # t03: bar?
    "t03": [
        "t03_bar_a06_bar_debiased",
        "t03_bar_a07_no_bar_debiased",
    ],
    # t04: spiral?
    "t04": [
        "t04_spiral_a08_spiral_debiased",
        "t04_spiral_a09_no_spiral_debiased",
    ],
    # t05: bulge prominence
    "t05": [
        "t05_bulge_prominence_a10_no_bulge_debiased",
        "t05_bulge_prominence_a11_just_noticeable_debiased",
        "t05_bulge_prominence_a12_obvious_debiased",
        "t05_bulge_prominence_a13_dominant_debiased",
    ],
    # t06: odd feature?
    "t06": [
        "t06_odd_a14_yes_debiased",
        "t06_odd_a15_no_debiased",
    ],
    # t07: roundedness (smooth galaxies)
    "t07": [
        "t07_rounded_a16_completely_round_debiased",
        "t07_rounded_a17_in_between_debiased",
        "t07_rounded_a18_cigar_shaped_debiased",
    ],
    # t08: odd feature type
    "t08": [
        "t08_odd_feature_a19_ring_debiased",
        "t08_odd_feature_a20_lens_or_arc_debiased",
        "t08_odd_feature_a21_disturbed_debiased",
        "t08_odd_feature_a22_irregular_debiased",
        "t08_odd_feature_a23_other_debiased",
        "t08_odd_feature_a24_merger_debiased",
        "t08_odd_feature_a38_dust_lane_debiased",
    ],
    # t09: bulge shape (edge-on only)
    "t09": [
        "t09_bulge_shape_a25_rounded_debiased",
        "t09_bulge_shape_a26_boxy_debiased",
        "t09_bulge_shape_a27_no_bulge_debiased",
    ],
    # t10: arms winding
    "t10": [
        "t10_arms_winding_a28_tight_debiased",
        "t10_arms_winding_a29_medium_debiased",
        "t10_arms_winding_a30_loose_debiased",
    ],
    # t11: arms number
    "t11": [
        "t11_arms_number_a31_1_debiased",
        "t11_arms_number_a32_2_debiased",
        "t11_arms_number_a33_3_debiased",
        "t11_arms_number_a34_4_debiased",
        "t11_arms_number_a36_more_than_4_debiased",
        "t11_arms_number_a37_cant_tell_debiased",
    ],
}

# Flat list of all 37 answer columns, preserving tree order.
LABEL_COLUMNS = [col for cols in _QUESTION_ANSWERS.values() for col in cols]

# Slice indices into LABEL_COLUMNS for each question group,
# computed from the per-question answer lists above.
QUESTION_GROUPS = {}
_offset = 0
for _q_name, _q_cols in _QUESTION_ANSWERS.items():
    QUESTION_GROUPS[_q_name] = (_offset, _offset + len(_q_cols))
    _offset += len(_q_cols)

# Parent answer column for hierarchical branch weighting.
# w_q = vote fraction of the parent answer that unlocks question q.
# t01 is the root question; its weight is always 1.0.
QUESTION_PARENT_COL = {
    "t01": None,
    "t02": "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t03": "t02_edgeon_a05_no_debiased",
    "t04": "t02_edgeon_a05_no_debiased",
    "t05": "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t06": "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t07": "t01_smooth_or_features_a01_smooth_debiased",
    "t08": "t06_odd_a14_yes_debiased",
    "t09": "t02_edgeon_a04_yes_debiased",
    "t10": "t04_spiral_a08_spiral_debiased",
    "t11": "t04_spiral_a08_spiral_debiased",
}

N_LABELS = len(LABEL_COLUMNS)  # 37
# ─────────────────────────────────────────────────────────────
# Image transforms
# ─────────────────────────────────────────────────────────────
def get_transforms(image_size: int, split: str) -> transforms.Compose:
    """
    Build the preprocessing pipeline for one dataset split.

    Training applies geometric augmentation — flips and rotations are
    label-preserving because galaxies have no preferred orientation — plus
    mild colour jitter for instrument variation. Validation/test only
    resize. Both paths end with ToTensor + ImageNet normalisation.
    """
    normalise = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    if split != "train":
        # Deterministic eval path: plain resize, no augmentation.
        return transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            normalise,
        ])

    # Resize slightly larger, then random-crop back: cheap translation jitter.
    return transforms.Compose([
        transforms.Resize((image_size + 16, image_size + 16)),
        transforms.RandomCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(180),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
        transforms.ToTensor(),
        normalise,
    ])
# ─────────────────────────────────────────────────────────────
# Dataset
# ─────────────────────────────────────────────────────────────
class GalaxyZoo2Dataset(Dataset):
    """
    PyTorch Dataset for Galaxy Zoo 2.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain all 37 LABEL_COLUMNS plus a ``dr7objid`` column.
    image_dir : str
        Directory holding one ``<dr7objid>.jpg`` file per row.
    transform : callable
        Applied to the loaded PIL image (see ``get_transforms``).

    Returns (per item)
    -------
    image : FloatTensor [3, H, W] normalised galaxy image
    targets : FloatTensor [37] vote fraction vector
    weights : FloatTensor [11] per-question hierarchical weights
    image_id : int dr7objid for traceability
    """

    def __init__(self, df: pd.DataFrame, image_dir: str, transform):
        self.df = df.reset_index(drop=True)
        self.image_dir = Path(image_dir)
        self.transform = transform
        # to_numpy with an explicit dtype converts in one step, avoiding the
        # float64 intermediate that `.values.astype(np.float32)` would copy.
        self.labels = self.df[LABEL_COLUMNS].to_numpy(dtype=np.float32)
        self.weights = self._compute_weights()
        self.image_ids = self.df["dr7objid"].tolist()

    def _compute_weights(self) -> np.ndarray:
        """
        Per-question hierarchical weights, shape [n_rows, 11].

        Each question's weight is the vote fraction of the parent answer
        that unlocks it (QUESTION_PARENT_COL); the root question t01 has
        no parent and keeps weight 1.0.
        """
        n = len(self.df)
        q_names = list(QUESTION_GROUPS.keys())
        weights = np.ones((n, len(q_names)), dtype=np.float32)
        for q_idx, q_name in enumerate(q_names):
            parent_col = QUESTION_PARENT_COL[q_name]
            if parent_col is not None:
                weights[:, q_idx] = self.df[parent_col].to_numpy(dtype=np.float32)
        return weights

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        image_id = self.image_ids[idx]
        img_path = self.image_dir / f"{image_id}.jpg"
        try:
            image = Image.open(img_path).convert("RGB")
        except FileNotFoundError as err:
            # Chain the original exception so the OS-level path/errno
            # context is preserved in the traceback.
            raise FileNotFoundError(
                f"Image not found: {img_path}. "
                f"Check dr7objid {image_id} has a matching .jpg file."
            ) from err
        image = self.transform(image)
        targets = torch.from_numpy(self.labels[idx])
        weights = torch.from_numpy(self.weights[idx])
        return image, targets, weights, image_id
# ─────────────────────────────────────────────────────────────
# DataLoader factory
# ─────────────────────────────────────────────────────────────
def build_dataloaders(cfg: DictConfig):
    """
    Build train / val / test DataLoaders from the labels parquet.

    Parameters
    ----------
    cfg : DictConfig
        Hydra/OmegaConf config. Reads ``cfg.seed``, ``cfg.training.batch_size``
        and ``cfg.data.{parquet_path, n_samples, image_dir, image_size,
        train_frac, val_frac, num_workers, pin_memory, persistent_workers,
        prefetch_factor}``.

    Returns
    -------
    (train_loader, val_loader, test_loader)

    Raises
    ------
    ValueError
        If any of the 37 label columns is missing from the parquet file.
    """
    log.info("Loading parquet: %s", cfg.data.parquet_path)
    df = pd.read_parquet(cfg.data.parquet_path)

    missing = [c for c in LABEL_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"Missing columns in parquet: {missing}")

    if cfg.data.n_samples is not None:
        n = int(cfg.data.n_samples)
        log.info("Using subset of %d samples (full dataset: %d)", n, len(df))
        df = df.sample(n=n, random_state=cfg.seed).reset_index(drop=True)
    else:
        log.info("Using full dataset: %d samples", len(df))

    # Deterministic shuffled split: fractions of the (possibly subsampled) df.
    rng = np.random.default_rng(cfg.seed)
    idx = rng.permutation(len(df))
    n = len(df)
    n_train = math.floor(cfg.data.train_frac * n)
    n_val = math.floor(cfg.data.val_frac * n)
    train_idx = idx[:n_train]
    val_idx = idx[n_train : n_train + n_val]
    test_idx = idx[n_train + n_val :]
    log.info("Split - train: %d val: %d test: %d",
             len(train_idx), len(val_idx), len(test_idx))

    image_size = cfg.data.image_size
    train_ds = GalaxyZoo2Dataset(
        df.iloc[train_idx], cfg.data.image_dir,
        get_transforms(image_size, "train"))
    val_ds = GalaxyZoo2Dataset(
        df.iloc[val_idx], cfg.data.image_dir,
        get_transforms(image_size, "val"))
    test_ds = GalaxyZoo2Dataset(
        df.iloc[test_idx], cfg.data.image_dir,
        get_transforms(image_size, "test"))

    num_workers = cfg.data.num_workers
    common = dict(
        batch_size=cfg.training.batch_size,
        num_workers=num_workers,
        pin_memory=cfg.data.pin_memory,
        drop_last=False,
    )
    # persistent_workers / prefetch_factor are only valid with worker
    # processes; passing them when num_workers == 0 makes DataLoader raise
    # a ValueError, so only forward them when workers are enabled.
    if num_workers > 0:
        common["persistent_workers"] = getattr(cfg.data, "persistent_workers", True)
        common["prefetch_factor"] = getattr(cfg.data, "prefetch_factor", 4)

    train_loader = DataLoader(train_ds, shuffle=True, **common)
    val_loader = DataLoader(val_ds, shuffle=False, **common)
    test_loader = DataLoader(test_ds, shuffle=False, **common)
    return train_loader, val_loader, test_loader