JetDDPM-SAE: Checkpoints
Two-stage model on quark/gluon calorimeter jet images (3-channel, 125×125):
- DDPM UNet trained to generate realistic jet images.
- Sparse Autoencoder (SAE) trained on the frozen UNet's ups[0] activations for interpretability.
Checkpoints
| File | Description |
|---|---|
| diff_checkpoint.pth | Trained diffusion UNet (raw state_dict) |
| sae_upblock_checkpoint.pth | SAE trained on ups[0] activations at t=100 |
Loading diffusion model
import torch
from src.diffusion import Unet
from src.noise_scheduler import LinNoiseScheduler
model = Unet()
model.load_state_dict(torch.load("diff_checkpoint.pth", map_location="cpu"))
model.eval()
scheduler = LinNoiseScheduler(num_of_timesteps=1000)
# generate: start from noise and call scheduler.sample_prev_timestep()
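The generation comment above can be made concrete with a generic DDPM ancestral sampling loop. This is a minimal sketch, not the repository's implementation: the signature of `scheduler.sample_prev_timestep()` is not shown here, so the loop below rebuilds a linear schedule by hand. The beta range (1e-4 to 0.02), the `noise_model(x, t)` call convention, and the helper names `make_linear_schedule` and `ddpm_sample` are all assumptions made for illustration.

```python
import torch

def make_linear_schedule(num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    # Assumed beta range (common DDPM default); the real LinNoiseScheduler may differ.
    betas = torch.linspace(beta_start, beta_end, num_timesteps)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    return betas, alphas, alpha_bars

@torch.no_grad()
def ddpm_sample(noise_model, shape, num_timesteps=1000):
    """Ancestral sampling: start from Gaussian noise and denoise step by step."""
    betas, alphas, alpha_bars = make_linear_schedule(num_timesteps)
    x = torch.randn(shape)
    for t in reversed(range(num_timesteps)):
        t_batch = torch.full((shape[0],), t, dtype=torch.long)
        eps = noise_model(x, t_batch)  # predicted noise, same shape as x
        # Posterior mean of x_{t-1} given x_t and the predicted noise.
        mean = (x - betas[t] / torch.sqrt(1.0 - alpha_bars[t]) * eps) / torch.sqrt(alphas[t])
        if t > 0:
            # Variance choice sigma_t^2 = beta_t; no noise is added at the last step.
            x = mean + torch.sqrt(betas[t]) * torch.randn(shape)
        else:
            x = mean
    return x

# Hypothetical stand-in for the trained Unet so the sketch runs on its own.
zero_model = lambda x, t: torch.zeros_like(x)
sample = ddpm_sample(zero_model, shape=(2, 3, 8, 8), num_timesteps=50)
```

With the real checkpoint, `noise_model` would be the loaded `model` and `shape` would be `(batch, 3, 125, 125)`, provided the Unet's forward pass accepts an image batch and a timestep tensor in that order.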
Loading SAE
import torch
from src.sae import SparseAutoencoder
ckpt = torch.load("sae_upblock_checkpoint.pth", map_location="cpu")
model = SparseAutoencoder(input_dim=ckpt["input_dim"], hidden_dim=1024)
model.load_state_dict(ckpt["model"])
model.eval()
# pass global-avg-pooled ups[0] activations -> sparse feature vector z
recon, z = model(activation_vector)
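One way to obtain `activation_vector` is a forward hook on `ups[0]` followed by global average pooling over the spatial dimensions. The sketch below is an illustration under assumptions: `ToyUNet` is a hypothetical stand-in (the real `Unet` internals are not shown here), the forward signature `unet(x, t)` is assumed, and the hooked block is assumed to return a single tensor rather than a tuple.

```python
import torch
import torch.nn as nn

class ToyUNet(nn.Module):
    """Hypothetical stand-in exposing a `ups` ModuleList, as the README's Unet appears to."""
    def __init__(self):
        super().__init__()
        self.down = nn.Conv2d(3, 256, kernel_size=3, padding=1)
        self.ups = nn.ModuleList([nn.Conv2d(256, 256, kernel_size=3, padding=1)])
        self.out = nn.Conv2d(256, 3, kernel_size=1)

    def forward(self, x, t):
        # The timestep is ignored in this toy model.
        h = self.down(x)
        h = self.ups[0](h)
        return self.out(h)

def pooled_ups0_activations(unet, x, t):
    """Run one forward pass and return ups[0] output averaged over H and W."""
    captured = {}

    def hook(module, inputs, output):
        captured["act"] = output.detach()

    handle = unet.ups[0].register_forward_hook(hook)
    try:
        with torch.no_grad():
            unet(x, t)
    finally:
        handle.remove()  # always detach the hook, even if the forward pass fails
    return captured["act"].mean(dim=(2, 3))  # (batch, channels)

unet = ToyUNet().eval()
x = torch.randn(4, 3, 31, 31)                  # stand-in for noised jet images
t = torch.full((4,), 100, dtype=torch.long)    # fixed t=100, as used for SAE training
activation_vector = pooled_ups0_activations(unet, x, t)
```

Here the pooled output has one entry per channel, so with 256 channels at `ups[0]` the SAE input is a 256-dimensional vector per image.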
Training details
- Diffusion: MSE noise prediction, Adam lr=1e-4, 50 epochs, batch 32
- SAE hook: ups[0] (first decoder block, 31×31, 256ch), fixed t=100
- SAE: MSE + L1(z)·1e-3, Adam lr=1e-3, 50 epochs on 30k activations
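The SAE objective listed above (reconstruction MSE plus an L1 penalty on z weighted by 1e-3, optimized with Adam at lr=1e-3) can be sketched as follows. This is a minimal illustration, not the code in `src.sae`: `TinySAE` and `sae_loss` are hypothetical names, the ReLU encoder is an assumption, and whether the original L1 term is a mean or a sum over z is not stated, so a mean is used here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySAE(nn.Module):
    """Hypothetical single-layer sparse autoencoder returning (recon, z)."""
    def __init__(self, input_dim=256, hidden_dim=1024):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        z = F.relu(self.encoder(x))  # non-negative feature activations (assumed ReLU)
        return self.decoder(z), z

def sae_loss(recon, z, target, l1_weight=1e-3):
    # Reconstruction error plus sparsity penalty on the hidden code.
    return F.mse_loss(recon, target) + l1_weight * z.abs().mean()

torch.manual_seed(0)
sae = TinySAE()
optimizer = torch.optim.Adam(sae.parameters(), lr=1e-3)
batch = torch.randn(32, 256)  # stand-in for pooled ups[0] activations

losses = []
for _ in range(20):
    recon, z = sae(batch)
    loss = sae_loss(recon, z, batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

The L1 weight trades reconstruction quality against sparsity: a larger value pushes more entries of z to zero at the cost of a higher MSE.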