lschmidt commited on
Commit
ac12760
·
verified ·
1 Parent(s): a3e23a5

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +9 -7
README.md CHANGED
@@ -47,6 +47,8 @@ It is adapted for downscaling of **2-channel ERA5 data** (e.g., wind u and v com
47
  from super_image import EdsrModel, EdsrConfig
48
  from huggingface_hub import hf_hub_download
49
  import torch
 
 
50
 
51
  # load config
52
  config, _ = EdsrConfig.from_pretrained("lschmidt/edsr-dsc")
@@ -62,18 +64,18 @@ state_dict = torch.load(state_dict_path, map_location="cpu")
62
  model.load_state_dict(state_dict, strict=False)
63
 
64
  # create random input: must be a 4D tensor (B, C=2, H, W)
65
- inputs = torch.randn(1, 2, 64, 64) # replace with coarse wind velocity fields
66
 
67
  # or use sample data
68
- import xarray as xr
69
- import numpy as np
70
-
71
- data_path = hf_hub_download(repo_id="lschmidt/edsr-dsc/test_data", filename="test_wind_velocities.nc")
 
72
  ds = xr.open_dataset(data_path)
73
  u = ds["u100"].values[0]
74
  v = ds["v100"].values[0]
75
- inputs = torch.from_numpy(np.stack([u, v], axis=0)).unsqueeze(0).float() # shape (1, 2, H, W)
76
 
77
  # prediction
78
  outputs = model(inputs)
79
-
 
47
  from super_image import EdsrModel, EdsrConfig
48
  from huggingface_hub import hf_hub_download
49
  import torch
50
+ import xarray as xr
51
+ import numpy as np
52
 
53
  # load config
54
  config, _ = EdsrConfig.from_pretrained("lschmidt/edsr-dsc")
 
64
  model.load_state_dict(state_dict, strict=False)
65
 
66
  # create random input: must be a 4D tensor (B, C=2, H, W)
67
+ inputs = torch.randn(1, 2, 40, 40) # replace with coarse wind velocity fields
68
 
69
  # or use sample data
70
+ data_path = hf_hub_download(
71
+ repo_id="lschmidt/edsr-dsc",
72
+ filename="test_wind_velocities.nc",
73
+ subfolder="test_data"
74
+ )
75
  ds = xr.open_dataset(data_path)
76
  u = ds["u100"].values[0]
77
  v = ds["v100"].values[0]
78
+ inputs = torch.from_numpy(np.stack([u, v], axis=0)).unsqueeze(0).float()
79
 
80
  # prediction
81
  outputs = model(inputs)