JetDDPM-SAE: Checkpoints

Two-stage model on quark/gluon calorimeter jet images (3-channel, 125×125):

  1. DDPM UNet trained to generate realistic jet images.
  2. Sparse Autoencoder on frozen UNet ups[0] activations for interpretability.

Checkpoints

File                         Description
diff_checkpoint.pth          Trained diffusion UNet (raw state_dict)
sae_upblock_checkpoint.pth   SAE trained on ups[0] activations at t=100

Loading diffusion model

import torch
from src.diffusion import Unet
from src.noise_scheduler import LinNoiseScheduler

model = Unet()
model.load_state_dict(torch.load("diff_checkpoint.pth", map_location="cpu"))
model.eval()

scheduler = LinNoiseScheduler(num_of_timesteps=1000)
# generate: start from noise and call scheduler.sample_prev_timestep()
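
The generation comment above leaves the sampling loop implicit. The sketch below shows standard DDPM ancestral sampling in plain PyTorch, so it does not rely on the signature of scheduler.sample_prev_timestep(), which this card does not document. The function name ddpm_sample, the linear beta range (1e-4 to 0.02), and the eps_model(x, t) call convention are assumptions for illustration and are not taken from the repository.

```python
import torch

def ddpm_sample(eps_model, shape, num_timesteps=1000,
                beta_start=1e-4, beta_end=0.02):
    """Ancestral DDPM sampling with a linear beta schedule (assumed values)."""
    betas = torch.linspace(beta_start, beta_end, num_timesteps)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)

    x = torch.randn(shape)  # start from pure Gaussian noise
    with torch.no_grad():
        for t in reversed(range(num_timesteps)):
            t_batch = torch.full((shape[0],), t, dtype=torch.long)
            eps = eps_model(x, t_batch)  # assumed: returns predicted noise
            # Posterior mean of x_{t-1} given x_t and the predicted noise.
            mean = (x - betas[t] / torch.sqrt(1.0 - alpha_bars[t]) * eps) \
                   / torch.sqrt(alphas[t])
            if t > 0:
                var = betas[t] * (1.0 - alpha_bars[t - 1]) / (1.0 - alpha_bars[t])
                x = mean + torch.sqrt(var) * torch.randn_like(x)
            else:
                x = mean  # no noise is added at the final step
    return x

# Demo with a hypothetical stand-in that always predicts zero noise.
dummy_eps_model = lambda x, t: torch.zeros_like(x)
sample = ddpm_sample(dummy_eps_model, (2, 3, 8, 8), num_timesteps=10)
```

With the real checkpoint, the call would look like ddpm_sample(model, (4, 3, 125, 125)), provided the UNet accepts (x, t) and returns the predicted noise. If LinNoiseScheduler uses a different beta range, the values above would need to match it.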

Loading SAE

import torch
from src.sae import SparseAutoencoder

ckpt  = torch.load("sae_upblock_checkpoint.pth", map_location="cpu")
model = SparseAutoencoder(input_dim=ckpt["input_dim"], hidden_dim=1024)
model.load_state_dict(ckpt["model"])
model.eval()

# pass global-avg-pooled ups[0] activations → sparse feature vector z
recon, z = model(activation_vector)
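
The snippet above assumes activation_vector already exists. One way to obtain it is a forward hook on ups[0] followed by global average pooling over the spatial dimensions. The sketch below uses a hypothetical stand-in network (TinyDecoder) and a hypothetical helper (capture_pooled); it assumes the hooked block returns a single tensor of shape (batch, channels, height, width).

```python
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Hypothetical stand-in for the UNet; only exposes an `ups` list."""
    def __init__(self):
        super().__init__()
        self.ups = nn.ModuleList([nn.Conv2d(8, 256, kernel_size=3, padding=1)])

    def forward(self, x):
        return self.ups[0](x)

def capture_pooled(model, block, *inputs):
    """Run model(*inputs) and return the block's output, averaged over H and W."""
    captured = {}

    def hook(module, args, output):
        captured["act"] = output.detach()

    handle = block.register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        handle.remove()  # always detach the hook, even if the forward pass fails
    return captured["act"].mean(dim=(2, 3))  # (batch, channels)

net = TinyDecoder().eval()
pooled = capture_pooled(net, net.ups[0], torch.randn(4, 8, 31, 31))
```

For the real model, the inputs would be the noised image and the timestep tensor (t=100), and the pooled result would be passed to the SAE as activation_vector.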

Training details

  • Diffusion: MSE noise prediction, Adam lr=1e-4, 50 epochs, batch 32
  • SAE hook: ups[0] (first decoder block, 31×31, 256ch), fixed t=100
  • SAE: MSE + L1(z)·1e-3, Adam lr=1e-3, 50 epochs on 30k activations
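
The SAE objective listed above (MSE plus an L1 penalty on z weighted by 1e-3) can be sketched as follows. TinySAE is a hypothetical stand-in for src.sae.SparseAutoencoder that only mirrors its (recon, z) return convention. The ReLU encoder and the mean reduction of the L1 term are assumptions; the actual implementation may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySAE(nn.Module):
    """Hypothetical single-layer sparse autoencoder returning (recon, z)."""
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        z = F.relu(self.encoder(x))  # non-negative sparse code
        return self.decoder(z), z

def sae_loss(x, recon, z, l1_coeff=1e-3):
    # Reconstruction error plus sparsity penalty (mean reduction is assumed).
    return F.mse_loss(recon, x) + l1_coeff * z.abs().mean()

torch.manual_seed(0)
sae = TinySAE(input_dim=256, hidden_dim=1024)
optimizer = torch.optim.Adam(sae.parameters(), lr=1e-3)
x = torch.randn(32, 256)  # random stand-in for pooled ups[0] activations

losses = []
for _ in range(20):
    recon, z = sae(x)
    loss = sae_loss(x, recon, z)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

Raising l1_coeff trades reconstruction quality for sparser codes. The value 1e-3 is the one stated in the training details.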