---
license: mit
library_name: pytorch
pipeline_tag: image-to-image
tags:
- solar-physics
- solar-forecasting
- swin-transformer
- pytorch
- aia
- patch-size-4
datasets:
- hrrsmjd/AIA_12hour_512x512
---
# Solaris Small Patch 4
This repository contains a Solaris-Small checkpoint trained for 12-hour multi-wavelength solar forecasting, following the Solaris pretraining setup from [Solaris: A Foundation Model of the Sun](https://arxiv.org/abs/2411.16339).
This run uses patch size 4. The earlier patch-size-8 checkpoint is published separately as `hrrsmjd/solaris_small_patch8`.
The checkpoint was trained on `hrrsmjd/AIA_12hour_512x512` for 7750 optimizer steps using two history frames (`t-12h`, `t`) to predict all eight pretraining wavelengths at `t+12h`.
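
Each training sample pairs two input frames with a target frame 12 hours ahead. The sketch below shows only that time indexing. It assumes UTC `datetime` timestamps, and the helper name `forecast_window` is hypothetical, not part of the Solaris code.

```python
from datetime import datetime, timedelta, timezone

CADENCE = timedelta(hours=12)  # 12-hour spacing for both the history and the forecast horizon

def forecast_window(t: datetime) -> tuple[list[datetime], datetime]:
    """Return the input timestamps (t-12h, t) and the target timestamp t+12h (hypothetical helper)."""
    inputs = [t - CADENCE, t]
    target = t + CADENCE
    return inputs, target

# Example: a sample anchored at 2015-06-01 12:00 UTC
inputs, target = forecast_window(datetime(2015, 6, 1, 12, tzinfo=timezone.utc))
print([ts.isoformat() for ts in inputs], target.isoformat())
```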
## Files
- `solaris_small_patch4_model_state_dict.pt`: reusable PyTorch checkpoint containing `model_state_dict`, learned normalization coefficients, wavelengths, scale factors, patch size, seed, training step, and final training loss.
- `config.json`: lightweight metadata for reconstructing the model and normalization.
- `assets/solaris_small_patch4_test0_prediction.png`: example qualitative test prediction plot.
- `eval/solaris_pretrain_paperloss_p4_ema_seed42_test_mse_subset_0352.md`: full test-split raw-scale MSE/RMSE/MAE report.
## Example Plot
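The plot below shows one test sample with rows for `input t-12h`, `input t`, `target t+12h`, `prediction t+12h`, and `prediction - target` across all eight wavelengths.

![Example test prediction for Solaris Small Patch 4](assets/solaris_small_patch4_test0_prediction.png)
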
## Model Details
- Architecture: `SolarisSmall`
- Patch size: `4`
- Embedding dimension: `256`
- Encoder depths: `(2, 6, 2)`
- Decoder depths: `(2, 6, 2)`
- Output wavelengths: `0094`, `0131`, `0171`, `0193`, `0211`, `0304`, `0335`, `1600`
- Training dataset: `hrrsmjd/AIA_12hour_512x512`
- Training target: 12-hour forecast
- Training budget: 7750 optimizer steps, batch size 8, gradient accumulation 4
- Seed: `42`
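
The patch size determines how many tokens the encoder receives per frame. The worked sketch below assumes that the 512x512 frames of `hrrsmjd/AIA_12hour_512x512` are split into non-overlapping square patches, as in a standard Swin patch embedding. This is an assumption: the exact Solaris embedding is not described here, and `patch_grid` is a hypothetical helper.

```python
def patch_grid(image_size: int, patch_size: int) -> tuple[int, int]:
    """Return (patches per side, total patches) for non-overlapping square patches."""
    if image_size % patch_size != 0:
        raise ValueError("image size must be divisible by the patch size")
    side = image_size // patch_size
    return side, side * side

for p in (4, 8):
    side, total = patch_grid(512, p)
    print(f"patch size {p}: {side}x{side} grid = {total} tokens per frame")
```

Under this assumption, halving the patch size from 8 to 4 gives four times as many tokens per frame as the `hrrsmjd/solaris_small_patch8` checkpoint.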
## Loading
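```python
import torch
from solaris.model.solaris import SolarisSmall

# weights_only=False unpickles arbitrary objects; only use it on checkpoints from a trusted source.
checkpoint = torch.load("solaris_small_patch4_model_state_dict.pt", map_location="cpu", weights_only=False)

# Rebuild the architecture from the stored metadata: one output channel per wavelength.
model = SolarisSmall(
    out_levels=len(checkpoint["wavelengths"]),
    patch_size=checkpoint["patch_size"],
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # switch to inference mode

# Per-wavelength factors that map normalized outputs back to raw intensities.
scale_factors = torch.tensor(checkpoint["scale_factors"], dtype=torch.float32)
# Learned coefficients of the Solaris input normalization.
norm_coeff_1 = checkpoint["norm_coeff_1"]
norm_coeff_2 = checkpoint["norm_coeff_2"]
```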
Inputs should be normalized with the Solaris transform used during training. Model outputs are normalized intensities; multiply by the per-wavelength scale factors before comparing to raw-intensity targets.
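
Converting outputs back to raw intensities is a per-wavelength multiply. The NumPy sketch below assumes that predictions use a `(batch, wavelength, height, width)` layout, with wavelengths in the order listed under Model Details. Both the layout and the helper name `to_raw_scale` are assumptions. The same reshape-and-multiply applies to `torch` tensors.

```python
import numpy as np

def to_raw_scale(pred_normalized: np.ndarray, scale_factors) -> np.ndarray:
    """Multiply each wavelength channel by its scale factor (hypothetical helper).

    Assumes a (batch, wavelength, height, width) layout; check this against your pipeline.
    """
    factors = np.asarray(scale_factors, dtype=pred_normalized.dtype)
    if pred_normalized.shape[1] != factors.shape[0]:
        raise ValueError("channel count does not match the number of scale factors")
    # Reshape to (1, C, 1, 1) so the factors broadcast over the batch and spatial dimensions.
    return pred_normalized * factors.reshape(1, -1, 1, 1)

scale_factors = [58.224720422037755, 216.21549287451052, 1616.446579054541, 2551.0149615718674,
                 1190.0182024885178, 887.1800787601859, 112.33733897339224, 266.61844876445224]
pred = np.ones((2, 8, 4, 4), dtype=np.float64)  # dummy normalized prediction
raw = to_raw_scale(pred, scale_factors)
print(raw[0, :, 0, 0])
```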
## Test Metrics
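Metrics below use all 352 test samples and are computed on the raw intensity scale. The **Mean** row is the unweighted average of the eight per-wavelength values, so the mean RMSE is not the square root of the mean MSE. The regular (non-EMA) final weights are recommended. The EMA weights from the training checkpoint scored worse on the full test split and are not included in this model-state checkpoint.

| Wavelength (Å) | MSE | RMSE | MAE |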
|---:|---:|---:|---:|
| 0094 | 8.87581 | 2.97923 | 0.240353 |
| 0131 | 243.534 | 15.6056 | 1.446 |
| 0171 | 9067.35 | 95.2226 | 37.4301 |
| 0193 | 18337.8 | 135.417 | 56.6422 |
| 0211 | 3811.31 | 61.7358 | 23.776 |
| 0304 | 1089.31 | 33.0047 | 12.0236 |
| 0335 | 31.7769 | 5.6371 | 1.63412 |
| 1600 | 54.0763 | 7.35366 | 3.33375 |
| **Mean** | **4080.5** | **44.6195** | **17.0658** |
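
The NumPy sketch below shows how per-wavelength raw-scale metrics and an unweighted mean row can be computed. It assumes raw-scale predictions and targets shaped `(samples, wavelength, height, width)`, with errors pooled over all samples and pixels. The function name is hypothetical, and this is not the evaluation script that produced the report.

```python
import numpy as np

def per_wavelength_metrics(pred: np.ndarray, target: np.ndarray) -> dict:
    """MSE, RMSE, and MAE per wavelength, pooled over samples and pixels (assumed layout N, C, H, W)."""
    err = pred - target
    axes = (0, 2, 3)  # reduce over everything except the wavelength axis
    mse = np.mean(err ** 2, axis=axes)
    mae = np.mean(np.abs(err), axis=axes)
    rmse = np.sqrt(mse)  # each wavelength's RMSE is the square root of its MSE
    return {
        "mse": mse, "rmse": rmse, "mae": mae,
        # Unweighted average across wavelengths, matching the Mean row.
        "mean": {"mse": mse.mean(), "rmse": rmse.mean(), "mae": mae.mean()},
    }

# Reproduce the Mean row's RMSE from the table's RMSE column.
table_rmse = np.array([2.97923, 15.6056, 95.2226, 135.417, 61.7358, 33.0047, 5.6371, 7.35366])
print(round(float(table_rmse.mean()), 4))  # 44.6195
```

The square root of the mean MSE (4080.5) would instead be about 63.88, so the two aggregations are not interchangeable.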
## Training Notes
Scale factors were computed as half the average per-image maximum over unique train-split timestamps:
```text
[58.224720422037755, 216.21549287451052, 1616.446579054541, 2551.0149615718674, 1190.0182024885178, 887.1800787601859, 112.33733897339224, 266.61844876445224]
```
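
The scale-factor rule can be sketched as follows. The sketch assumes that each unique train-split timestamp maps to one multi-wavelength image of shape `(wavelength, height, width)`. The dictionary input and the name `compute_scale_factors` are hypothetical. Keying by timestamp presumably ensures that frames shared by overlapping 12-hour samples are counted only once.

```python
import numpy as np

def compute_scale_factors(images_by_timestamp: dict) -> np.ndarray:
    """Half the average per-image maximum, computed separately for each wavelength.

    images_by_timestamp maps a timestamp to a (wavelength, height, width) array;
    because it is a dict, each timestamp contributes exactly one image.
    """
    stacked = np.stack(list(images_by_timestamp.values()))  # (N, C, H, W)
    per_image_max = stacked.max(axis=(2, 3))  # (N, C): maximum of each image, per wavelength
    return 0.5 * per_image_max.mean(axis=0)  # (C,)

# Toy example with two timestamps and two wavelengths
images = {
    "2015-06-01T00:00": np.array([[[1.0, 4.0]], [[10.0, 2.0]]]),
    "2015-06-01T12:00": np.array([[[2.0, 6.0]], [[30.0, 5.0]]]),
}
print(compute_scale_factors(images))  # expected values: 2.5 and 10.0
```

For the full train split, a running sum of the per-image maxima avoids holding every frame in memory at once.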
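The final metrics logged at step 7750 were measured on a training batch rather than the test split, so they are not directly comparable to the test metrics above: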
```text
weighted MAE: 0.009133
mean raw RMSE: 25.278
per-wavelength raw RMSE: [2.535, 4.554, 62.817, 67.771, 28.201, 22.793, 2.572, 10.982]
```