Galaxy Chirality Classifier v2
A 3-class galaxy chirality model (clockwise, counterclockwise, or not spiral). It is trained on Galaxy Zoo citizen-science labels and CE-ResNet predictions, and benchmarked against CE-ResNet.
Model Details
- Architecture: ViT-Small (timm vit_small_patch16_224) + 3-class head (CW / CCW / NOT_SPIRAL)
- Training data: 26,626 images (GZ1 CW/CCW labels + CE-ResNet confident spirals + synthetic hard negatives)
- Validation accuracy: 93.7% (3-class)
- Per-class accuracy: CW 94.9%, CCW 91.3%, NOT_SPIRAL 99.4%
- Flip-equivariance: CW fraction of 0.5012 with test-time equivariant averaging, close to CE-ResNet's 0.5013
- Bias hardening: 8/8 tests PASS (flip-swap, rotation, artifacts, perturbation, leakage, hemispheric, calibration, CW balance)
- External benchmark: 91.5% agreement with CE-ResNet on 23K matched galaxies
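The CW fraction and agreement figures above can be made concrete with a short sketch. It assumes two things. First, the CW fraction is taken over spiral predictions only (CW / (CW + CCW)). Second, agreement is the share of matched galaxies where both classifiers assign the same label. Both function names are hypothetical and not part of this repository.

```python
from collections import Counter

def cw_fraction(labels):
    """Fraction of spiral predictions that are CW; NOT_SPIRAL is ignored.

    Assumption: CW fraction = n_CW / (n_CW + n_CCW). A model with no handedness
    bias, run on a mirror-symmetric population, should give a value near 0.5.
    """
    counts = Counter(labels)
    spirals = counts["CW"] + counts["CCW"]
    if spirals == 0:
        raise ValueError("no spiral predictions")
    return counts["CW"] / spirals

def agreement(preds_a, preds_b):
    """Share of matched galaxies where two classifiers assign the same label."""
    if len(preds_a) != len(preds_b) or not preds_a:
        raise ValueError("need two non-empty prediction lists of equal length")
    same = sum(a == b for a, b in zip(preds_a, preds_b))
    return same / len(preds_a)

ours = ["CW", "CCW", "CW", "NOT_SPIRAL", "CCW"]
ref = ["CW", "CCW", "CCW", "NOT_SPIRAL", "CCW"]
print(cw_fraction(ours))     # 2 CW out of 4 spirals -> 0.5
print(agreement(ours, ref))  # 4 of 5 labels match -> 0.8
```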
Usage
```python
import torch
import torch.nn as nn
import timm
from torchvision import transforms
from PIL import Image

# MLP classification head on the 384-dim pooled ViT-Small features
class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.h = nn.Sequential(
            nn.LayerNorm(384), nn.Linear(384, 512), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.GELU(), nn.Dropout(0.2), nn.Linear(256, 3))

    def forward(self, x):
        return self.h(x)

# num_classes=0 makes timm return pooled features instead of logits
encoder = timm.create_model("vit_small_patch16_224", pretrained=False, num_classes=0)
head = Head()

# map_location="cpu" lets the checkpoint load on machines without a GPU
ckpt = torch.load("chirality_model_v2_best.pt", map_location="cpu", weights_only=True)
encoder.load_state_dict(ckpt["enc"])
head.load_state_dict(ckpt["head"])
model = nn.Sequential(encoder, head).eval()

# Resize to the ViT input size and normalize with ImageNet mean/std
tfm = transforms.Compose([
    transforms.Resize((224, 224)), transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# convert("RGB") handles grayscale or RGBA input files
img = Image.open("galaxy.jpg").convert("RGB")
with torch.no_grad():
    probs = torch.softmax(model(tfm(img).unsqueeze(0)), dim=1)[0]

classes = ["CW", "CCW", "NOT_SPIRAL"]
idx = probs.argmax().item()
print(f"Prediction: {classes[idx]} ({probs[idx].item():.1%})")
```
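The usage snippet makes a single forward pass. The reported CW fraction instead uses test-time equivariant averaging, which is not shown.

The sketch below shows one way to do this averaging. It assumes the averaging combines two predictions: the one for the image, and the CW/CCW-swapped one for the image's horizontal mirror.

It uses NumPy and a generic `predict` callable (hypothetical), so it runs without the checkpoint. With the PyTorch model above, the mirror would be `torch.flip(x, dims=[-1])` on the NCHW input tensor.

```python
import numpy as np

# Class order matches the model card: 0 = CW, 1 = CCW, 2 = NOT_SPIRAL.
# Mirroring an image reverses its handedness, so CW and CCW trade places.
SWAP = np.array([1, 0, 2])

def hflip(img):
    """Horizontal flip of an (H, W) or (H, W, C) image array."""
    return img[:, ::-1]

def equivariant_predict(predict, img):
    """Average p(x) with the CW/CCW-swapped p(mirror(x)).

    For any `predict`, the result satisfies out(mirror(x)) == swap(out(x)).
    """
    p = np.asarray(predict(img), dtype=float)
    p_flip = np.asarray(predict(hflip(img)), dtype=float)
    return 0.5 * (p + p_flip[SWAP])

# Toy stand-in for the network (hypothetical): biased toward CW and sensitive
# to which half of the image is brighter, so it is NOT equivariant by itself.
def toy_predict(img):
    half = img.shape[1] // 2
    left, right = img[:, :half].mean(), img[:, half:].mean()
    logits = np.array([1.0 + left - right, 0.5, 0.2])
    e = np.exp(logits - logits.max())
    return e / e.sum()

rng = np.random.default_rng(0)
x = rng.random((8, 8))
a = equivariant_predict(toy_predict, x)
b = equivariant_predict(toy_predict, hflip(x))
print(np.allclose(b, a[SWAP]))  # True: mirroring the input swaps CW and CCW
```

This averaging enforces the symmetry by construction, independent of how well the training-time consistency loss worked. It costs a second forward pass per image.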
Training Methodology
- Flip-equivariance consistency loss (explicit MSE penalty when the prediction for a mirrored image differs from the CW/CCW-swapped prediction for the original)
- Chirality-aware augmentation (horizontal flip swaps CW/CCW label)
- Heavy rotation augmentation (0-360 degrees)
- Hard negative mining (blank sky, scrambled images, ellipticals as NOT_SPIRAL)
- 6 of the 12 ViT-Small encoder blocks unfrozen and fine-tuned jointly with the head
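The sketch below illustrates two chirality-specific ideas from the list above, using NumPy. The first is an augmentation that swaps the CW/CCW label whenever the image is mirrored. Rotations leave handedness unchanged, so they keep the label. The second is an MSE consistency term that penalizes a model when its prediction for a mirrored image is not the CW/CCW-swapped prediction for the original.

A few assumptions apply:
- The function names are hypothetical.
- `np.rot90` stands in for the arbitrary 0-360 degree rotations; in torchvision, `transforms.RandomRotation(degrees=(0, 360))` would be one option.
- The actual training code's loss form and weighting may differ.

```python
import numpy as np

CW, CCW, NOT_SPIRAL = 0, 1, 2
SWAP = np.array([1, 0, 2])  # permutation that exchanges the CW and CCW entries

def augment(img, label, rng):
    """Chirality-aware augmentation for an (H, W) or (H, W, C) image.

    Rotation preserves handedness, so the label is unchanged. A horizontal flip
    mirrors the spiral, so CW becomes CCW and vice versa; NOT_SPIRAL is unaffected.
    """
    img = np.rot90(img, k=int(rng.integers(4)))  # stand-in for arbitrary-angle rotation
    if rng.random() < 0.5:
        img = img[:, ::-1]
        label = SWAP[label]
    return img, int(label)

def flip_consistency_loss(p, p_flipped):
    """MSE between predictions on mirrored images and CW/CCW-swapped originals.

    p, p_flipped: (N, 3) class probabilities for a batch and its horizontal mirror.
    The loss is zero exactly when the model is flip-equivariant on that batch.
    """
    p = np.asarray(p, dtype=float)
    p_flipped = np.asarray(p_flipped, dtype=float)
    return float(np.mean((p_flipped - p[:, SWAP]) ** 2))

rng = np.random.default_rng(1)
aug_img, aug_label = augment(rng.random((16, 16)), CW, rng)
print(aug_img.shape, aug_label)  # label is CW (0) or CCW (1), depending on the flip

p = np.array([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]])
print(flip_consistency_loss(p, p[:, SWAP]))  # 0.0 for an equivariant model
print(flip_consistency_loss(p, p) > 0)       # True: ignoring the mirror is penalized
```

The asymmetry is deliberate. Rotations are label-preserving, so they can be applied freely. Flips are only safe if the CW/CCW label is swapped alongside them; otherwise the augmentation would teach the model that handedness does not matter.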
Catalog
Running inference on 8.47M galaxies from Smith42/galaxies (DESI Legacy Survey DR8).
Author
Houston Golden (BigBounce Research)