Solaris Small Patch 8

This repository contains a Solaris-Small checkpoint trained for 12-hour multi-wavelength solar forecasting, following the pretraining setup from the paper "Solaris: A Foundation Model of the Sun".

This checkpoint supersedes the earlier patch-size-8 upload. It was trained with a chronological 80/10/10 split, a 24-hour guard band at split boundaries, AdamW weight decay 0.05, and no EMA weights.
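A chronological split with a guard band can be sketched as follows. This is a minimal illustration, not the training code: the function name `chronological_split` is hypothetical, and the rule used here (drop samples from the earlier split that fall closer than 24 hours to the next split's first timestamp) is an assumption, since the card does not state exactly how the guard band was applied.

```python
from datetime import datetime, timedelta


def chronological_split(timestamps, percents=(80, 10, 10), guard=timedelta(hours=24)):
    """Split timestamps in time order and drop samples too close to a boundary."""
    ts = sorted(timestamps)
    n = len(ts)
    n_train = n * percents[0] // 100
    n_val = n * percents[1] // 100
    # Boundary times: first timestamp of the validation and test blocks.
    val_start = ts[n_train]
    test_start = ts[n_train + n_val]
    # Assumption: the guard band trims the tail of the earlier split only.
    train = [t for t in ts[:n_train] if val_start - t >= guard]
    val = [t for t in ts[n_train:n_train + n_val] if test_start - t >= guard]
    test = ts[n_train + n_val:]
    return train, val, test


# Toy data: 100 samples at a 12-hour cadence.
start = datetime(2020, 1, 1)
toy = [start + timedelta(hours=12 * i) for i in range(100)]
train, val, test = chronological_split(toy)
print(len(train), len(val), len(test))
```

Because the split is chronological, every training sample precedes every validation sample, and the guard band keeps frames adjacent to a boundary from appearing on both sides of it.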

The checkpoint was trained on hrrsmjd/AIA_12hour_512x512 for 7750 optimizer steps using two history frames (t-12h, t) to predict all eight pretraining wavelengths at t+12h.
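The input/target pairing described above can be illustrated with a small sketch. It assumes a list of frame timestamps at a 12-hour cadence and treats a sample as valid only when the frames at t-12h, t, and t+12h all exist; the helper name `build_samples` and this validity rule are illustrative assumptions, not taken from the training code.

```python
from datetime import datetime, timedelta

STEP = timedelta(hours=12)


def build_samples(timestamps, step=STEP):
    """Return (t - step, t, t + step) triples for which all three frames exist."""
    available = set(timestamps)
    samples = []
    for t in sorted(available):
        if t - step in available and t + step in available:
            # Two history frames as input, one future frame as target.
            samples.append((t - step, t, t + step))
    return samples


# Toy timeline with one missing frame at hour 48.
start = datetime(2020, 1, 1)
hours = [0, 12, 24, 36, 60, 72, 84]
frames = [start + timedelta(hours=h) for h in hours]
samples = build_samples(frames)
print(len(samples))
```

A gap in the timeline removes every sample whose history or target would fall on the missing frame, which is one reason the number of valid samples can be smaller than the number of frames.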

Files

  • solaris_small_patch8_model_state_dict.pt: reusable PyTorch checkpoint containing model_state_dict, learned normalization coefficients, wavelengths, scale factors, patch size, seed, split metadata, test metrics, training step, and final training loss.
  • config.json: lightweight metadata for reconstructing the model and normalization.
  • assets/solaris_small_patch8_test0_prediction.png: example qualitative test prediction plot.
  • eval/solaris_pretrain_p8_chronosplit_wd005_noema_test_mse_subset_0941.md: full chronological test-split raw-scale MSE/RMSE/MAE report.

Example Plot

The plot below shows one test sample with rows for input t-12h, input t, target t+12h, prediction t+12h, and prediction - target across all eight wavelengths.

[Figure: Solaris Small Patch 8 test prediction (assets/solaris_small_patch8_test0_prediction.png)]

Model Details

  • Architecture: SolarisSmall
  • Patch size: 8
  • Embedding dimension: 256
  • Encoder depths: (2, 6, 2)
  • Decoder depths: (2, 6, 2)
  • Output wavelengths (Å): 0094, 0131, 0171, 0193, 0211, 0304, 0335, 1600
  • Training dataset: hrrsmjd/AIA_12hour_512x512
  • Training target: 12-hour forecast
  • Split scheme: chronological 80/10/10 over valid samples with a 24-hour boundary guard band
  • Split counts: train 7541, validation 939, test 941
  • Training budget: 7750 optimizer steps, batch size 8, gradient accumulation 4
  • Optimizer: AdamW, lr=5e-4, cosine decay to 5e-5, weight_decay=0.05, betas (0.9, 0.95)
  • EMA: disabled
  • Seed: 42
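The learning-rate schedule and training budget listed above can be sketched in a few lines. This assumes a plain cosine decay with no warmup (the card does not mention one) and assumes that batch size 8 refers to the micro-batch, so that gradient accumulation 4 yields 32 samples per optimizer step. The closed form is the same one documented for PyTorch's `CosineAnnealingLR`; whether training used that class is not stated here.

```python
import math

LR_MAX = 5e-4
LR_MIN = 5e-5
TOTAL_STEPS = 7750
BATCH_SIZE = 8   # assumed to be the micro-batch size
GRAD_ACCUM = 4


def cosine_lr(step, total_steps=TOTAL_STEPS, lr_max=LR_MAX, lr_min=LR_MIN):
    """Cosine decay from lr_max at step 0 to lr_min at total_steps (no warmup)."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


# Samples consumed per optimizer step under the micro-batch assumption.
samples_per_step = BATCH_SIZE * GRAD_ACCUM

for step in (0, TOTAL_STEPS // 2, TOTAL_STEPS):
    print(step, cosine_lr(step))
```

Under these assumptions the run consumes 7750 × 32 = 248,000 training samples in total, or roughly 33 passes over the 7541-sample training split.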

Loading

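import torch

from solaris.model.solaris import SolarisSmall

# weights_only=False lets torch.load unpickle the non-tensor metadata stored
# alongside the weights. Unpickling can execute arbitrary code, so only load
# checkpoint files you trust.
checkpoint = torch.load("solaris_small_patch8_model_state_dict.pt", map_location="cpu", weights_only=False)

# Rebuild the architecture from the stored metadata, then restore the weights.
model = SolarisSmall(
    out_levels=len(checkpoint["wavelengths"]),
    patch_size=checkpoint["patch_size"],
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode: disables dropout and similar training-only behavior

# Per-wavelength scale factors and learned normalization coefficients.
scale_factors = torch.tensor(checkpoint["scale_factors"], dtype=torch.float32)
norm_coeff_1 = checkpoint["norm_coeff_1"]
norm_coeff_2 = checkpoint["norm_coeff_2"]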

Inputs should be normalized with the Solaris transform used during training. Model outputs are normalized intensities; multiply by the per-wavelength scale factors before comparing to raw-intensity targets.
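The rescaling step can be sketched as below. The example uses NumPy and a dummy array as a stand-in for model outputs, and it assumes a (batch, wavelength, height, width) layout; the function name `to_raw_intensity` is hypothetical. With PyTorch tensors the same broadcasting applies, for example via `scale_factors.view(1, -1, 1, 1)`.

```python
import numpy as np

# Per-wavelength scale factors from the Training Notes section.
scale_factors = np.array(
    [57.944149103084534, 214.99738922760267, 1590.2998402078304,
     2397.489401917806, 1080.261734048243, 830.778793198845,
     104.45557294825853, 274.65685334356664],
    dtype=np.float32,
)


def to_raw_intensity(pred, scale_factors):
    """Rescale normalized predictions to the raw intensity scale.

    Assumes pred has shape (batch, wavelength, height, width).
    """
    # Reshape to (1, wavelength, 1, 1) so each channel gets its own factor.
    return pred * scale_factors.reshape(1, -1, 1, 1)


# Dummy normalized output: all ones, so each channel equals its scale factor.
pred = np.ones((2, 8, 4, 4), dtype=np.float32)
raw = to_raw_intensity(pred, scale_factors)
print(raw.shape)
```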

Test Metrics

Metrics below use all 941 chronological test samples and are computed on the raw intensity scale.

Wavelength (Å)  MSE      RMSE     MAE
0094            7.08928  2.66257  0.310692
0131            147.528  12.1461  1.44793
0171            13730.9  117.179  52.9684
0193            22906.6  151.349  62.3088
0211            7752.18  88.0464  35.0413
0304            1556.87  39.4572  14.2213
0335            48.0333  6.9306   2.0811
1600            145.126  12.0468  7.3929
Mean            5786.78  53.7272  21.9716
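When reading the table, note that each RMSE value matches the square root of the MSE for that wavelength, and the Mean row is consistent with an arithmetic mean over the eight wavelengths. The mean RMSE is therefore not the square root of the mean MSE. The short check below reproduces this from the values above; it is a sanity check on the reported numbers, not the evaluation code.

```python
import math

# MSE and RMSE values copied from the table above, keyed by wavelength.
mse = {
    "0094": 7.08928, "0131": 147.528, "0171": 13730.9, "0193": 22906.6,
    "0211": 7752.18, "0304": 1556.87, "0335": 48.0333, "1600": 145.126,
}
rmse = {
    "0094": 2.66257, "0131": 12.1461, "0171": 117.179, "0193": 151.349,
    "0211": 88.0464, "0304": 39.4572, "0335": 6.9306, "1600": 12.0468,
}

# Largest relative gap between sqrt(MSE) and the reported RMSE.
max_rel_gap = max(abs(math.sqrt(mse[w]) - rmse[w]) / rmse[w] for w in mse)

# Mean of per-wavelength RMSE (the "Mean" row) vs. RMSE of the mean MSE.
mean_rmse = sum(rmse.values()) / len(rmse)
rmse_of_mean_mse = math.sqrt(sum(mse.values()) / len(mse))

print(max_rel_gap, mean_rmse, rmse_of_mean_mse)
```

The two aggregates differ because the bright 0171, 0193, and 0211 channels dominate the mean MSE, so comparisons against other models should use the same aggregation.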

Training Notes

Scale factors were computed as half the average per-image maximum over unique train-split timestamps:

[57.944149103084534, 214.99738922760267, 1590.2998402078304, 2397.489401917806, 1080.261734048243, 830.778793198845, 104.45557294825853, 274.65685334356664]
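The scale-factor rule stated above can be sketched as follows, assuming the train-split images are stacked into one array of shape (timestamps, wavelengths, height, width) with duplicate timestamps already removed. The function name `compute_scale_factors` and the array layout are illustrative assumptions.

```python
import numpy as np


def compute_scale_factors(images):
    """Half the average per-image maximum, computed per wavelength.

    images: array of shape (num_timestamps, num_wavelengths, height, width).
    """
    per_image_max = images.max(axis=(2, 3))   # (num_timestamps, num_wavelengths)
    return 0.5 * per_image_max.mean(axis=0)   # (num_wavelengths,)


# Toy data: two timestamps, two wavelengths, 2x2 images.
images = np.zeros((2, 2, 2, 2), dtype=np.float64)
images[0, 0, 0, 0] = 4.0   # wavelength 0 maxima: 4 and 8
images[1, 0, 1, 1] = 8.0
images[0, 1, 0, 1] = 2.0   # wavelength 1 maxima: 2 and 2
images[1, 1, 0, 0] = 2.0
factors = compute_scale_factors(images)
print(factors)
```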

The final logged training-batch metrics at step 7750 were:

weighted MAE: 0.007597
mean raw RMSE: 19.576
per-wavelength raw RMSE: [1.210, 4.034, 49.249, 53.399, 20.770, 17.612, 1.653, 8.684]