---
license: mit
library_name: pytorch
pipeline_tag: image-to-image
tags:
- solar-physics
- solar-forecasting
- swin-transformer
- pytorch
- aia
- patch-size-8
datasets:
- hrrsmjd/AIA_12hour_512x512
---

# Solaris Small Patch 8
|
|
This repository contains a Solaris-Small checkpoint trained for 12-hour multi-wavelength solar forecasting, following the Solaris pretraining setup from [Solaris: A Foundation Model of the Sun](https://arxiv.org/abs/2411.16339).
|
|
This checkpoint supersedes the earlier patch-size-8 upload. It was trained with a chronological 80/10/10 split, a 24-hour guard band at split boundaries, AdamW weight decay `0.05`, and no EMA weights.
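
A chronological split with a guard band can be sketched as below. This is a hypothetical interpretation, not the training code: it assumes each sample is identified by its anchor timestamp `t`, places the cuts at 80% and 90% of the time-sorted samples, and drops any sample closer than 24 hours to a cut on either side. Because each sample spans `t-12h` to `t+12h`, a 24-hour band is one way to keep neighbouring splits from sharing frames.

```python
from datetime import datetime, timedelta


def chronological_split(timestamps, train_frac=0.8, val_frac=0.1, guard=timedelta(hours=24)):
    """Hypothetical sketch: time-ordered train/val/test split with a guard band.

    The test share is whatever remains after the first two cuts. Samples closer
    than `guard` to either boundary are dropped from whichever split holds them.
    """
    ts = sorted(timestamps)
    n = len(ts)
    cut1 = round(n * train_frac)
    cut2 = round(n * (train_frac + val_frac))
    # The first sample of the validation block and of the test block mark the boundaries.
    boundaries = (ts[cut1], ts[cut2])

    def outside_guard(t):
        return all(abs(t - b) >= guard for b in boundaries)

    parts = (ts[:cut1], ts[cut1:cut2], ts[cut2:])
    return tuple([t for t in part if outside_guard(t)] for part in parts)


# Usage with 100 synthetic anchors spaced 12 hours apart.
start = datetime(2020, 1, 1)
anchors = [start + timedelta(hours=12 * i) for i in range(100)]
train, val, test = chronological_split(anchors)
print(len(train), len(val), len(test))  # 79 7 8
```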
|
|
The checkpoint was trained on `hrrsmjd/AIA_12hour_512x512` for 7750 optimizer steps, using two history frames (`t-12h` and `t`) to predict all eight pretraining wavelengths at `t+12h`.
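
Pairing the two history frames with the forecast target can be sketched as follows. The 12-hour cadence is an assumption suggested by the dataset name, and `build_samples` is a hypothetical helper: it keeps an anchor time `t` only when the frames at `t-12h`, `t`, and `t+12h` all exist, which is one plausible reading of "valid samples".

```python
from datetime import datetime, timedelta

# Assumed cadence between consecutive frames (suggested by the dataset name).
STEP = timedelta(hours=12)


def build_samples(available_times):
    """Hypothetical sketch: return ((t-12h, t), t+12h) input/target tuples for
    every anchor t whose history frames and forecast target all exist."""
    have = set(available_times)
    samples = []
    for t in sorted(have):
        if t - STEP in have and t + STEP in have:
            samples.append(((t - STEP, t), t + STEP))
    return samples


# Usage: six 12-hour slots with the fourth frame missing.
base = datetime(2021, 6, 1)
times = [base + STEP * i for i in range(6) if i != 3]
samples = build_samples(times)
print(len(samples))  # 1: only t = base + 12h has both neighbours
```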
|
|
## Files


- `solaris_small_patch8_model_state_dict.pt`: reusable PyTorch checkpoint containing `model_state_dict`, learned normalization coefficients, wavelengths, scale factors, patch size, seed, split metadata, test metrics, training step, and final training loss.
- `config.json`: lightweight metadata for reconstructing the model and normalization.
- `assets/solaris_small_patch8_test0_prediction.png`: example qualitative test prediction plot.
- `eval/solaris_pretrain_p8_chronosplit_wd005_noema_test_mse_subset_0941.md`: full chronological test-split raw-scale MSE/RMSE/MAE report.
|
|
## Example Plot


The plot below shows one test sample with rows for `input t-12h`, `input t`, `target t+12h`, `prediction t+12h`, and `prediction - target` across all eight wavelengths.
|
|
![Example Solaris Small Patch 8 test prediction](assets/solaris_small_patch8_test0_prediction.png)
|
|
## Model Details


- Architecture: `SolarisSmall`
- Patch size: `8`
- Embedding dimension: `256`
- Encoder depths: `(2, 6, 2)`
- Decoder depths: `(2, 6, 2)`
- Output wavelengths: `0094`, `0131`, `0171`, `0193`, `0211`, `0304`, `0335`, `1600`
- Training dataset: `hrrsmjd/AIA_12hour_512x512`
- Training target: 12-hour forecast
- Split scheme: chronological 80/10/10 over valid samples with a 24-hour boundary guard band
- Split counts: train `7541`, validation `939`, test `941`
- Training budget: 7750 optimizer steps, batch size 8, gradient accumulation 4
- Optimizer: AdamW, `lr=5e-4`, cosine decay to `5e-5`, `weight_decay=0.05`, betas `(0.9, 0.95)`
- EMA: disabled
- Seed: `42`
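
The schedule and batch settings above can be checked numerically. This sketch assumes the cosine decay runs per optimizer step over all 7750 steps with no warmup (warmup is not documented here), using the same closed form as PyTorch's `CosineAnnealingLR`. It also assumes batch size 8 is the per-step batch on a single device, so gradient accumulation 4 gives 32 samples per optimizer step.

```python
import math


def cosine_lr(step, total_steps=7750, lr_max=5e-4, lr_min=5e-5):
    """Cosine decay from lr_max at step 0 to lr_min at total_steps (no warmup assumed)."""
    step = min(max(step, 0), total_steps)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * step / total_steps))


# Samples per optimizer step, assuming batch size 8 is per step on one device.
effective_batch = 8 * 4
samples_seen = effective_batch * 7750
approx_epochs = samples_seen / 7541  # 7541 training samples

print(cosine_lr(0), cosine_lr(3875), cosine_lr(7750))
print(effective_batch, samples_seen, round(approx_epochs, 1))  # 32 248000 32.9
```

Under these assumptions the run covers roughly 33 passes over the training split; a different device count or warmup would change both numbers.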
|
|
## Loading
|
|
```python
import torch

from solaris.model.solaris import SolarisSmall

# weights_only=False unpickles arbitrary Python objects; only use it on files you trust.
checkpoint = torch.load("solaris_small_patch8_model_state_dict.pt", map_location="cpu", weights_only=False)

# Rebuild the model with one output channel per wavelength and the stored patch size.
model = SolarisSmall(
    out_levels=len(checkpoint["wavelengths"]),
    patch_size=checkpoint["patch_size"],
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode (disables dropout etc.)

# Per-wavelength factors that map normalized outputs back to raw intensities.
scale_factors = torch.tensor(checkpoint["scale_factors"], dtype=torch.float32)
# Learned coefficients of the Solaris input normalization.
norm_coeff_1 = checkpoint["norm_coeff_1"]
norm_coeff_2 = checkpoint["norm_coeff_2"]
```
|
|
Inputs should be normalized with the Solaris transform used during training, whose learned coefficients are stored in the checkpoint as `norm_coeff_1` and `norm_coeff_2`. Model outputs are normalized intensities; multiply them by the per-wavelength scale factors before comparing them to raw-intensity targets.
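
Converting normalized outputs back to raw intensity is a per-channel multiplication. The sketch below uses NumPy and assumes predictions are shaped `(batch, wavelength, height, width)` with the scale factors in the same order as the output wavelengths listed under Model Details. `to_raw_intensity` is a hypothetical helper; the same broadcasting works on the PyTorch `scale_factors` tensor from the loading snippet, for example `pred * scale_factors[:, None, None]`.

```python
import numpy as np

# Per-wavelength scale factors from the Training Notes (assumed order: 0094, 0131, ..., 1600).
SCALE_FACTORS = np.array(
    [57.944149103084534, 214.99738922760267, 1590.2998402078304, 2397.489401917806,
     1080.261734048243, 830.778793198845, 104.45557294825853, 274.65685334356664],
    dtype=np.float32,
)


def to_raw_intensity(pred_normalized, scale_factors=SCALE_FACTORS):
    """Multiply normalized predictions shaped (..., C, H, W) by one factor per channel C."""
    pred = np.asarray(pred_normalized, dtype=np.float32)
    if pred.shape[-3] != scale_factors.shape[0]:
        raise ValueError(f"expected {scale_factors.shape[0]} channels, got {pred.shape[-3]}")
    # (C,) -> (C, 1, 1) so each channel's factor broadcasts over its H x W pixels.
    return pred * scale_factors[:, None, None]


# Usage: an all-ones batch maps to the scale factors themselves at every pixel.
raw = to_raw_intensity(np.ones((2, 8, 4, 4)))
print(raw.shape)  # (2, 8, 4, 4)
```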
|
|
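## Test Metrics


Metrics below use all 941 chronological test samples and are computed on the raw intensity scale. The **Mean** row is the unweighted average of each column across the eight wavelengths, so the mean RMSE is the average of the per-wavelength RMSEs rather than the square root of the mean MSE.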
|
|
| Wavelength (Å) | MSE | RMSE | MAE |
| |---:|---:|---:|---:| |
| | 0094 | 7.08928 | 2.66257 | 0.310692 | |
| | 0131 | 147.528 | 12.1461 | 1.44793 | |
| | 0171 | 13730.9 | 117.179 | 52.9684 | |
| | 0193 | 22906.6 | 151.349 | 62.3088 | |
| | 0211 | 7752.18 | 88.0464 | 35.0413 | |
| | 0304 | 1556.87 | 39.4572 | 14.2213 | |
| | 0335 | 48.0333 | 6.9306 | 2.0811 | |
| | 1600 | 145.126 | 12.0468 | 7.3929 | |
| | **Mean** | **5786.78** | **53.7272** | **21.9716** | |
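
The sketch below shows one way to produce the per-wavelength columns and the **Mean** row. `per_wavelength_metrics` is a hypothetical helper that assumes raw-scale arrays shaped `(N, C, H, W)` and pools errors over all samples and pixels; the evaluation report may reduce differently, although each RMSE entry in the table is indeed the square root of its MSE entry. The second part recomputes the Mean row from the table values and contrasts it with the square root of the mean MSE.

```python
import numpy as np


def per_wavelength_metrics(pred, target):
    """Per-channel MSE, RMSE, and MAE for raw-scale arrays shaped (N, C, H, W)."""
    err = np.asarray(pred, dtype=np.float64) - np.asarray(target, dtype=np.float64)
    axes = (0, 2, 3)  # pool over samples and pixels, keep the wavelength axis
    mse = np.mean(err**2, axis=axes)
    mae = np.mean(np.abs(err), axis=axes)
    return mse, np.sqrt(mse), mae


# Per-wavelength values copied from the table (0094 ... 1600).
mse = np.array([7.08928, 147.528, 13730.9, 22906.6, 7752.18, 1556.87, 48.0333, 145.126])
rmse = np.array([2.66257, 12.1461, 117.179, 151.349, 88.0464, 39.4572, 6.9306, 12.0468])
mae = np.array([0.310692, 1.44793, 52.9684, 62.3088, 35.0413, 14.2213, 2.0811, 7.3929])

# Unweighted column means reproduce the Mean row up to rounding of the entries.
print(round(float(mse.mean()), 2), round(float(rmse.mean()), 4), round(float(mae.mean()), 4))
# The square root of the mean MSE is a different quantity (about 76.07).
print(round(float(np.sqrt(mse.mean())), 2))
```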
|
|
## Training Notes


Scale factors were computed as half the average per-image maximum over unique train-split timestamps:
|
|
```text
[57.944149103084534, 214.99738922760267, 1590.2998402078304, 2397.489401917806, 1080.261734048243, 830.778793198845, 104.45557294825853, 274.65685334356664]
```
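
The scale-factor recipe above can be sketched as follows. `compute_scale_factors` is a hypothetical helper that takes a mapping from timestamp to a `(C, H, W)` array for the training split. Keying by timestamp counts each frame once even though a frame can appear in up to three samples (as `t-12h`, `t`, or `t+12h`), which is presumably why the notes specify unique timestamps. Real data would be loaded lazily rather than held in memory.

```python
import numpy as np


def compute_scale_factors(frames_by_timestamp):
    """Half the mean per-image maximum, computed separately for each wavelength.

    frames_by_timestamp: dict mapping a unique train-split timestamp to an
    array shaped (C, H, W). Returns an array of C scale factors.
    """
    per_image_max = []
    for frame in frames_by_timestamp.values():
        arr = np.asarray(frame, dtype=np.float64)
        # Maximum over all pixels of each channel -> shape (C,)
        per_image_max.append(arr.reshape(arr.shape[0], -1).max(axis=1))
    return 0.5 * np.mean(per_image_max, axis=0)


# Usage: two timestamps, two channels with per-image maxima (2, 10) and (4, 30).
a = np.zeros((2, 3, 3))
a[0, 1, 1] = 2.0
a[1, 0, 0] = 10.0
b = np.zeros((2, 3, 3))
b[0, 2, 2] = 4.0
b[1, 1, 2] = 30.0
print(compute_scale_factors({"t0": a, "t1": b}))  # factors 1.5 and 10.0
```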
|
|
The final logged training-batch metrics at step 7750 were:


```text
weighted MAE: 0.007597
mean raw RMSE: 19.576
per-wavelength raw RMSE: [1.210, 4.034, 49.249, 53.399, 20.770, 17.612, 1.653, 8.684]
```
|
|