Galaxy Chirality Classifier v2

A 3-class galaxy chirality classifier trained on Galaxy Zoo citizen-science labels and CE-ResNet benchmarks.

Model Details

  • Architecture: ViT-Small (timm vit_small_patch16_224) + 3-class head (CW / CCW / NOT_SPIRAL)
  • Training data: 26,626 images (GZ1 CW/CCW labels + CE-ResNet confident spirals + synthetic hard negatives)
  • Validation accuracy: 93.7% (3-class)
  • Per-class: CW 94.9%, CCW 91.3%, NOT_SPIRAL 99.4%
  • Flip-equivariance: CW fraction = 0.5012 with test-time equivariant averaging (close to CE-ResNet's 0.5013)
  • Bias hardening: 8/8 tests PASS (flip-swap, rotation, artifacts, perturbation, leakage, hemispheric, calibration, CW balance)
  • External benchmark: 91.5% agreement with CE-ResNet on 23K matched galaxies
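
The CW fraction quoted above can be read as the share of CW predictions among galaxies classified as spirals. The sketch below assumes that definition (CW / (CW + CCW), ignoring NOT_SPIRAL), which the card does not state explicitly. The helper name `cw_fraction` and the toy labels are illustrative only. The standard error is the usual binomial one under the null hypothesis of no handedness preference (p = 0.5).

```python
from collections import Counter
from math import sqrt

def cw_fraction(labels):
    """Return (CW fraction among spirals, binomial standard error at p = 0.5).

    Assumes the fraction is CW / (CW + CCW); NOT_SPIRAL labels are ignored.
    """
    counts = Counter(labels)
    n_spiral = counts["CW"] + counts["CCW"]
    if n_spiral == 0:
        raise ValueError("no spiral predictions to compute a fraction from")
    frac = counts["CW"] / n_spiral
    # Standard error of a proportion when the true value is 0.5.
    stderr = sqrt(0.25 / n_spiral)
    return frac, stderr

# Toy labels for illustration only; these are not model outputs.
toy_labels = ["CW"] * 5 + ["CCW"] * 5 + ["NOT_SPIRAL"] * 3
frac, stderr = cw_fraction(toy_labels)
print(f"CW fraction = {frac:.4f} +/- {stderr:.4f}")
```

Comparing the deviation from 0.5 against the standard error gives a quick sense of whether a measured asymmetry is larger than sampling noise would explain.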

Usage

import torch, timm, torch.nn as nn
from torchvision import transforms
from PIL import Image

# Classification head on top of the 384-dim ViT-Small embedding.
class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.h = nn.Sequential(
            nn.LayerNorm(384), nn.Linear(384,512), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(512,256), nn.GELU(), nn.Dropout(0.2), nn.Linear(256,3))
    def forward(self, x): return self.h(x)

# num_classes=0 makes timm return pooled features instead of class logits.
encoder = timm.create_model("vit_small_patch16_224", pretrained=False, num_classes=0)
head = Head()
# map_location="cpu" lets the checkpoint load on machines without a GPU.
ckpt = torch.load("chirality_model_v2_best.pt", map_location="cpu", weights_only=True)
encoder.load_state_dict(ckpt["enc"])
head.load_state_dict(ckpt["head"])

# eval() disables dropout for inference.
model = nn.Sequential(encoder, head).eval()
tfm = transforms.Compose([
    transforms.Resize((224,224)), transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])

# Force 3 channels so grayscale or RGBA files do not break Normalize.
img = Image.open("galaxy.jpg").convert("RGB")
with torch.no_grad():
    probs = torch.softmax(model(tfm(img).unsqueeze(0)), dim=1)[0]
classes = ["CW", "CCW", "NOT_SPIRAL"]
print(f"Prediction: {classes[probs.argmax().item()]} ({probs.max().item():.1%})")
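
The model details mention test-time equivariant averaging, which the single forward pass above does not apply. The card does not describe the exact procedure, so the following is only a sketch of one common approach: predict on the image and on its mirror image, swap the CW and CCW probabilities of the mirrored pass, and average. The function name `flip_averaged_probs` and the small stand-in model are hypothetical. With the real checkpoint, `model` from the usage snippet would take the stand-in's place.

```python
import torch
import torch.nn as nn

# Assumed class order, as in the usage snippet: CW, CCW, NOT_SPIRAL.
# Mirroring reverses handedness, so columns 0 and 1 trade places.
SWAP_CW_CCW = [1, 0, 2]

@torch.no_grad()
def flip_averaged_probs(model, x):
    """Average class probabilities over a batch and its horizontal mirror.

    x is a float tensor of shape (N, 3, H, W). Probabilities from the
    mirrored pass have CW and CCW swapped back before averaging.
    """
    p = torch.softmax(model(x), dim=1)
    p_flip = torch.softmax(model(torch.flip(x, dims=[3])), dim=1)
    return 0.5 * (p + p_flip[:, SWAP_CW_CCW])

# Stand-in model so the sketch runs without the checkpoint (not the real network).
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 3)).eval()
x = torch.randn(4, 3, 8, 8)

probs = flip_averaged_probs(toy_model, x)
probs_mirror = flip_averaged_probs(toy_model, torch.flip(x, dims=[3]))
```

By construction, the averaged prediction for a mirrored image is the original prediction with CW and CCW exchanged, whatever the underlying network does. That property removes any built-in handedness preference from the final output.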

Training Methodology

  • Flip-equivariance consistency loss (explicit MSE penalty for asymmetric predictions)
  • Chirality-aware augmentation (horizontal flip swaps CW/CCW label)
  • Heavy rotation augmentation (0-360 degrees)
  • Hard negative mining (blank sky, scrambled images, ellipticals as NOT_SPIRAL)
  • 6 encoder blocks unfrozen for end-to-end fine-tuning

Catalog

Catalog inference is in progress on 8.47M galaxies from Smith42/galaxies (DESI Legacy Survey DR8).

Author

Houston Golden (BigBounce Research)

