---
license: mit
library_name: pytorch
pipeline_tag: image-to-image
tags:
- solar-physics
- solar-forecasting
- swin-transformer
- pytorch
- aia
- patch-size-8
datasets:
- hrrsmjd/AIA_12hour_512x512
---
Solaris Small Patch 8
This repository contains a Solaris-Small checkpoint trained for 12-hour multi-wavelength solar forecasting, following the Solaris pretraining setup from "Solaris: A Foundation Model of the Sun".
This checkpoint supersedes the earlier patch-size-8 upload. It was trained with a chronological 80/10/10 split, a 24-hour guard band at split boundaries, AdamW weight decay 0.05, and no EMA weights.
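A minimal sketch of a chronological 80/10/10 split with a 24-hour guard band is shown below. It assumes, without confirmation from this card, that the guard band drops every sample whose timestamp lies within 24 hours of either split boundary. `chronological_split` is a hypothetical helper, not part of the Solaris code.

```python
from datetime import datetime, timedelta


def chronological_split(timestamps, fractions=(0.8, 0.1, 0.1), guard=timedelta(hours=24)):
    """Split timestamps chronologically into train/val/test.

    Hypothetical helper. Assumption: a sample is dropped when it lies strictly
    closer than `guard` to a split boundary (the first timestamp of the
    validation block or of the test block).
    """
    ts = sorted(timestamps)
    n = len(ts)
    n_train = int(round(n * fractions[0]))
    n_val = int(round(n * fractions[1]))
    boundaries = [ts[n_train], ts[n_train + n_val]]

    def keep(t):
        # Keep only samples at least `guard` away from every boundary.
        return all(abs(t - b) >= guard for b in boundaries)

    train = [t for t in ts[:n_train] if keep(t)]
    val = [t for t in ts[n_train:n_train + n_val] if keep(t)]
    test = [t for t in ts[n_train + n_val:] if keep(t)]
    return train, val, test


# Example: 100 samples at a 12-hour cadence.
start = datetime(2011, 1, 1)
samples = [start + timedelta(hours=12 * i) for i in range(100)]
train, val, test = chronological_split(samples)
print(len(train), len(val), len(test))
```

Splitting by time rather than at random keeps the test period strictly after the training period, and the guard band additionally keeps near-boundary samples, whose frames overlap the neighbouring split, out of the evaluation.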
The checkpoint was trained on hrrsmjd/AIA_12hour_512x512 for 7750 optimizer steps using two history frames (t-12h, t) to predict all eight pretraining wavelengths at t+12h.
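The input/target layout can be sketched as follows. The sketch assumes frames are stored as a gap-free `(T, C, H, W)` array at a 12-hour cadence. `make_forecast_samples` is a hypothetical helper, and the real pipeline presumably also filters out samples with missing frames.

```python
import numpy as np


def make_forecast_samples(frames, history=2, lead=1):
    """Build (history frames, target frame) pairs from a 12-hour-cadence stack.

    Hypothetical helper. frames: array of shape (T, C, H, W), assumed gap-free.
    With history=2 and lead=1, inputs are (t-12h, t) and the target is t+12h.
    """
    inputs, targets = [], []
    for i in range(history - 1, len(frames) - lead):
        inputs.append(frames[i - history + 1 : i + 1])  # (history, C, H, W)
        targets.append(frames[i + lead])                # (C, H, W)
    return np.stack(inputs), np.stack(targets)


# Toy stack of 5 frames with 8 wavelengths; each frame is filled with its index.
frames = np.arange(5, dtype=np.float32).reshape(5, 1, 1, 1) * np.ones((1, 8, 4, 4), dtype=np.float32)
x, y = make_forecast_samples(frames)
print(x.shape, y.shape)
```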
Files
- `solaris_small_patch8_model_state_dict.pt`: reusable PyTorch checkpoint containing `model_state_dict`, learned normalization coefficients, wavelengths, scale factors, patch size, seed, split metadata, test metrics, training step, and final training loss.
- `config.json`: lightweight metadata for reconstructing the model and normalization.
- `assets/solaris_small_patch8_test0_prediction.png`: example qualitative test prediction plot.
- `eval/solaris_pretrain_p8_chronosplit_wd005_noema_test_mse_subset_0941.md`: full chronological test-split raw-scale MSE/RMSE/MAE report.
Example Plot
The plot below shows one test sample with rows for input t-12h, input t, target t+12h, prediction t+12h, and prediction - target across all eight wavelengths.
Model Details
- Architecture: SolarisSmall
- Patch size: 8
- Embedding dimension: 256
- Encoder depths: (2, 6, 2)
- Decoder depths: (2, 6, 2)
- Output wavelengths (Å): 0094, 0131, 0171, 0193, 0211, 0304, 0335, 1600
- Training dataset: hrrsmjd/AIA_12hour_512x512
- Training target: 12-hour forecast
- Split scheme: chronological 80/10/10 over valid samples with a 24-hour boundary guard band
- Split counts: train 7541, validation 939, test 941
- Training budget: 7750 optimizer steps, batch size 8, gradient accumulation 4
- Optimizer: AdamW, `lr=5e-4`, cosine decay to `5e-5`, `weight_decay=0.05`, `betas=(0.9, 0.95)`
- EMA: disabled
- Seed: 42
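A sketch of the listed optimizer settings in plain PyTorch follows. The card does not say whether a warmup phase was used, so this assumes pure cosine decay stepped once per optimizer step. `build_optimizer`, `train_with_accumulation` and the L1 loss inside it are hypothetical placeholders, not the Solaris training loop.

```python
import torch
import torch.nn.functional as F


def build_optimizer(model, total_steps=7750, base_lr=5e-4, min_lr=5e-5):
    """AdamW with cosine decay from base_lr to min_lr (assumption: no warmup)."""
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=base_lr, betas=(0.9, 0.95), weight_decay=0.05
    )
    # One scheduler step per optimizer step, so T_max equals the step budget.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=total_steps, eta_min=min_lr
    )
    return optimizer, scheduler


def train_with_accumulation(model, micro_batches, optimizer, scheduler, accum_steps=4):
    """Take one optimizer step per `accum_steps` micro-batches.

    The L1 loss is a hypothetical placeholder for the real training objective.
    """
    model.train()
    optimizer.zero_grad()
    for i, (x, y) in enumerate(micro_batches, start=1):
        # Divide so the accumulated gradient is the mean over the window.
        loss = F.l1_loss(model(x), y) / accum_steps
        loss.backward()
        if i % accum_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()


# Toy run: 40 micro-batches with accumulation 4 give 10 optimizer steps.
model = torch.nn.Linear(4, 2)
opt, sched = build_optimizer(model, total_steps=10)
batches = [(torch.randn(8, 4), torch.randn(8, 2)) for _ in range(40)]
train_with_accumulation(model, batches, opt, sched, accum_steps=4)
print(sched.last_epoch, opt.param_groups[0]["lr"])
```

If "batch size 8" refers to each micro-batch, gradient accumulation 4 means every optimizer step aggregates gradients from 32 samples.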
Loading
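```python
import torch
from solaris.model.solaris import SolarisSmall

# Load the full checkpoint dictionary (weights plus metadata) on the CPU.
checkpoint = torch.load("solaris_small_patch8_model_state_dict.pt", map_location="cpu", weights_only=False)

# Rebuild the architecture from the stored metadata, then restore the weights.
model = SolarisSmall(
    out_levels=len(checkpoint["wavelengths"]),
    patch_size=checkpoint["patch_size"],
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Per-wavelength factors for mapping normalized outputs back to raw intensities.
scale_factors = torch.tensor(checkpoint["scale_factors"], dtype=torch.float32)
# Learned normalization coefficients stored with the checkpoint.
norm_coeff_1 = checkpoint["norm_coeff_1"]
norm_coeff_2 = checkpoint["norm_coeff_2"]
```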
Inputs should be normalized with the Solaris transform used during training. Model outputs are normalized intensities; multiply by the per-wavelength scale factors before comparing to raw-intensity targets.
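A minimal sketch of the denormalization step and of per-wavelength raw-scale errors follows. It assumes predictions are shaped `(batch, wavelength, height, width)` with channels ordered as in `checkpoint["wavelengths"]`, and that metrics average over samples and pixels. `denormalize` and `per_wavelength_metrics` are hypothetical helpers.

```python
import torch


def denormalize(pred, scale_factors):
    """Multiply normalized outputs by per-wavelength scale factors.

    Assumption: pred has shape (B, C, H, W) with channel order matching scale_factors.
    """
    scale = torch.as_tensor(scale_factors, dtype=pred.dtype, device=pred.device)
    return pred * scale.view(1, -1, 1, 1)  # broadcast over batch and pixels


def per_wavelength_metrics(pred_raw, target_raw):
    """Return (MSE, RMSE, MAE) per wavelength, averaged over batch and pixels."""
    err = pred_raw - target_raw
    dims = (0, 2, 3)
    mse = err.pow(2).mean(dim=dims)
    return mse, mse.sqrt(), err.abs().mean(dim=dims)


# Toy check: unit normalized predictions against an all-zero raw target.
pred = torch.ones(2, 8, 4, 4)
raw = denormalize(pred, [float(i + 1) for i in range(8)])
mse, rmse, mae = per_wavelength_metrics(raw, torch.zeros_like(raw))
print(rmse)
```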
Test Metrics
Metrics below use all 941 chronological test samples and are computed on the raw intensity scale.
| Wavelength (Å) | MSE | RMSE | MAE |
|---|---|---|---|
| 0094 | 7.08928 | 2.66257 | 0.310692 |
| 0131 | 147.528 | 12.1461 | 1.44793 |
| 0171 | 13730.9 | 117.179 | 52.9684 |
| 0193 | 22906.6 | 151.349 | 62.3088 |
| 0211 | 7752.18 | 88.0464 | 35.0413 |
| 0304 | 1556.87 | 39.4572 | 14.2213 |
| 0335 | 48.0333 | 6.9306 | 2.0811 |
| 1600 | 145.126 | 12.0468 | 7.3929 |
| Mean | 5786.78 | 53.7272 | 21.9716 |
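The Mean row is the unweighted average of the eight per-wavelength values in each column. Mean RMSE is therefore the average of the per-wavelength RMSEs, not the square root of the mean MSE (about 76.07). The check below uses only the numbers from the table.

```python
import math

# Per-wavelength values copied from the table (0094 ... 1600).
mse = [7.08928, 147.528, 13730.9, 22906.6, 7752.18, 1556.87, 48.0333, 145.126]
rmse = [2.66257, 12.1461, 117.179, 151.349, 88.0464, 39.4572, 6.9306, 12.0468]
mae = [0.310692, 1.44793, 52.9684, 62.3088, 35.0413, 14.2213, 2.0811, 7.3929]


def mean(values):
    """Unweighted arithmetic mean."""
    return sum(values) / len(values)


print(f"mean MSE  {mean(mse):.2f}")
print(f"mean RMSE {mean(rmse):.4f}")
print(f"mean MAE  {mean(mae):.4f}")
print(f"sqrt(mean MSE) {math.sqrt(mean(mse)):.2f}")
```

Because the largest errors come from 0171 and 0193, the square root of the mean MSE is dominated by those two channels, while the mean of the RMSEs weights every wavelength equally.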
Training Notes
Scale factors were computed as half the average per-image maximum over unique train-split timestamps:
[57.944149103084534, 214.99738922760267, 1590.2998402078304, 2397.489401917806, 1080.261734048243, 830.778793198845, 104.45557294825853, 274.65685334356664]
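A sketch of the scale-factor rule follows. It assumes one raw image of shape `(C, H, W)` per unique train-split timestamp, keyed by timestamp so that frames shared by overlapping samples are counted once. That reason for using unique timestamps is an assumption, and `compute_scale_factors` is a hypothetical helper.

```python
import numpy as np


def compute_scale_factors(images_by_timestamp):
    """Half the average per-image maximum, computed per wavelength channel.

    images_by_timestamp: dict mapping timestamp -> raw image of shape (C, H, W).
    Dict keys enforce one image per unique timestamp (an assumption).
    """
    maxima = np.stack([
        img.reshape(img.shape[0], -1).max(axis=1)  # per-channel maximum, shape (C,)
        for img in images_by_timestamp.values()
    ])  # shape (N, C)
    return 0.5 * maxima.mean(axis=0)


# Toy example with two timestamps and two channels.
images = {
    "t0": np.array([[[1.0, 2.0]], [[10.0, 5.0]]]),
    "t1": np.array([[[4.0, 0.0]], [[30.0, 3.0]]]),
}
print(compute_scale_factors(images))
```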
The final logged training-batch metrics at step 7750 were:
- weighted MAE: 0.007597
- mean raw RMSE: 19.576
- per-wavelength raw RMSE: [1.210, 4.034, 49.249, 53.399, 20.770, 17.612, 1.653, 8.684]
