---
license: mit
tags:
- downscaling
- edsr
- ERA5
- COSMO-REA6
- wind
- super-image
library_name: super-image
model_type: edsr
datasets:
- your-dataset-name
---

# EDSR-DSC (4× Downscaling of Wind Velocities)

This model is a custom-trained version of the Enhanced Deep Super-Resolution (EDSR) model from the [`super-image`](https://github.com/eugenesiow/super-image) library. It is adapted to downscale **2-channel ERA5 data** (the u and v wind components) by a factor of 4×, using **COSMO-REA6** as the high-resolution training target.
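
The model's output keeps the two wind components separate. As a minimal post-processing sketch, assuming a NumPy array laid out as (B, 2, H, W) with u in channel 0 and v in channel 1 (the order used in the usage example), the horizontal wind speed on each grid cell is sqrt(u² + v²). The helper `wind_speed` is hypothetical and not part of `super-image`.

```python
import numpy as np


def wind_speed(uv: np.ndarray) -> np.ndarray:
    """Return wind speed sqrt(u**2 + v**2) for an array shaped (B, 2, H, W).

    Hypothetical helper: assumes channel 0 is u and channel 1 is v.
    """
    if uv.ndim != 4 or uv.shape[1] != 2:
        raise ValueError(f"expected shape (B, 2, H, W), got {uv.shape}")
    u, v = uv[:, 0], uv[:, 1]
    # np.hypot computes sqrt(u**2 + v**2) elementwise
    return np.hypot(u, v)


# toy example: u = 3 m/s and v = 4 m/s everywhere on a 2x2 grid
uv = np.stack([np.full((2, 2), 3.0), np.full((2, 2), 4.0)])[np.newaxis]
speed = wind_speed(uv)
print(speed.shape)  # (1, 2, 2)
print(speed[0, 0, 0])  # 5.0
```

With PyTorch outputs, convert first, e.g. `outputs.detach().cpu().numpy()`.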

---

## 🧠 Model Architecture

- **Base**: EDSR ([Lim et al. 2017](https://arxiv.org/abs/1707.02921))
- **Input channels**: 2 (u and v wind components)
- **Output channels**: 2
- **Feature channels (`n_feats`)**: 64
- **Residual blocks**: 32
- **Mean-shift normalization**: Removed
- **Upsampling**: Enabled
- **Scale factor**: 4×
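
With a scale factor of 4×, each spatial dimension of the input grid is multiplied by 4 while the channel count stays at 2. The hypothetical helper below (not part of `super-image`) sketches that shape contract, for example to check inputs before calling the model; it assumes the (B, C, H, W) layout used in the usage example.

```python
from typing import Tuple

SCALE = 4       # scale factor listed above
N_CHANNELS = 2  # u and v wind components


def expected_output_shape(input_shape: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
    """Map an input shape (B, C, H, W) to the expected output shape."""
    if len(input_shape) != 4:
        raise ValueError(f"expected a 4D shape (B, C, H, W), got {input_shape}")
    batch, channels, height, width = input_shape
    if channels != N_CHANNELS:
        raise ValueError(f"expected {N_CHANNELS} channels, got {channels}")
    # channels are preserved; each spatial dimension is upsampled by SCALE
    return (batch, channels, height * SCALE, width * SCALE)


print(expected_output_shape((1, 2, 40, 40)))  # (1, 2, 160, 160)
```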

---

## 📦 Files in this Repository

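| File                                | Description                                                    |
|-------------------------------------|----------------------------------------------------------------|
| `config.json`                       | Configuration for the modified EDSR model                      |
| `pytorch_model_4x.pt`               | Pretrained weights (PyTorch `state_dict`) for 4× downscaling   |
| `test_data/test_wind_velocities.nc` | Sample wind data (`u100`, `v100`) used in the usage example    |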

---

## 🚀 How to Use

```python
import numpy as np
import torch
import xarray as xr
from huggingface_hub import hf_hub_download
from super_image import EdsrConfig, EdsrModel

# load config
config, _ = EdsrConfig.from_pretrained("lschmidt/edsr-dsc")

# build the model and disable mean-shift normalization:
# the RGB mean-shift layers are replaced by no-ops (deleting them would
# break the forward pass, which still calls sub_mean and add_mean)
model = EdsrModel(config)
model.sub_mean = torch.nn.Identity()
model.add_mean = torch.nn.Identity()

# load pre-trained weights
state_dict_path = hf_hub_download(repo_id="lschmidt/edsr-dsc", filename="pytorch_model_4x.pt")
state_dict = torch.load(state_dict_path, map_location="cpu")
# strict=False silently skips missing/unexpected keys
model.load_state_dict(state_dict, strict=False)
model.eval()

# option 1: random input, must be a 4D tensor (B, C=2, H, W)
inputs = torch.randn(1, 2, 40, 40)  # replace with coarse wind velocity fields

# option 2: sample data from this repository
data_path = hf_hub_download(
    repo_id="lschmidt/edsr-dsc",
    filename="test_wind_velocities.nc",
    subfolder="test_data",
)
ds = xr.open_dataset(data_path)
u = ds["u100"].values[0]  # first slice along the leading dimension
v = ds["v100"].values[0]
inputs = torch.from_numpy(np.stack([u, v], axis=0)).unsqueeze(0).float()  # (1, 2, H, W)

# prediction: output has shape (B, 2, 4*H, 4*W)
with torch.no_grad():
    outputs = model(inputs)
```