lschmidt commited on
Commit
96d6736
·
verified ·
1 Parent(s): 5f1b677

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +12 -1
README.md CHANGED
@@ -61,9 +61,20 @@ state_dict_path = hf_hub_download(repo_id="lschmidt/edsr-dsc", filename="pytorch
61
  state_dict = torch.load(state_dict_path, map_location="cpu")
62
  model.load_state_dict(state_dict, strict=False)
63
 
64
- # sample input: must be a 4D tensor (B, C=2, H, W)
65
  inputs = torch.randn(1, 2, 64, 64) # replace with coarse wind velocity fields
66
 
 
 
 
 
 
 
 
 
 
 
 
67
  # prediction
68
  outputs = model(inputs)
69
 
 
61
  state_dict = torch.load(state_dict_path, map_location="cpu")
62
  model.load_state_dict(state_dict, strict=False)
63
 
64
+ # create random input: must be a 4D tensor (B, C=2, H, W)
65
  inputs = torch.randn(1, 2, 64, 64) # replace with coarse wind velocity fields
66
 
67
+ # or use sample data
68
+ import xarray as xr
69
+ import numpy as np
70
+
71
+ data_path = hf_hub_download(repo_id="lschmidt/edsr-dsc/test_data", filename="test_wind_velocities.nc")
72
+ ds = xr.open_dataset(data_path)
73
+ u = ds["u100"].values[0]
74
+ v = ds["v100"].values[0]
75
+ inputs = torch.from_numpy(np.stack([u, v], axis=0)).unsqueeze(0).float() # shape (1, 2, H, W)
76
+
77
+
78
  # prediction
79
  outputs = model(inputs)
80