Update README.md
Browse files
README.md
CHANGED
|
@@ -46,5 +46,16 @@ model.load_state_dict(state_dict, strict=False)
|
|
| 46 |
# generate sample data (B, C, W, H)
|
| 47 |
inputs = torch.randn(1, 2, 10, 10)
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
# prediction
|
| 50 |
output = model(inputs)
|
|
|
|
| 46 |
# generate sample data (B, C, W, H)
|
| 47 |
inputs = torch.randn(1, 2, 10, 10)
|
| 48 |
|
| 49 |
+
# or use test data
|
| 50 |
+
data_path = hf_hub_download(
|
| 51 |
+
repo_id="lschmidt/rcan-dsc",
|
| 52 |
+
filename="test_wind_velocities.nc",
|
| 53 |
+
subfolder="test_data"
|
| 54 |
+
)
|
| 55 |
+
ds = xr.open_dataset(data_path)
|
| 56 |
+
u = ds["u100"].values[0]
|
| 57 |
+
v = ds["v100"].values[0]
|
| 58 |
+
inputs = torch.from_numpy(np.stack([u, v], axis=0)).unsqueeze(0).float()
|
| 59 |
+
|
| 60 |
# prediction
|
| 61 |
output = model(inputs)
|