lschmidt commited on
Commit
f4016be
·
verified ·
1 Parent(s): b57088c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +14 -3
README.md CHANGED
@@ -44,11 +44,22 @@ It is adapted for super-resolution of **2-channel weather data** (e.g., wind u a
44
  ## 🚀 How to Use
45
 
46
  ```python
47
- from super_image import EdsrModel
 
48
  import torch
49
 
50
- # Load model and weights directly from Hugging Face Hub
51
- model = EdsrModel.from_pretrained("lschmidt/edsr-dsc", scale=4)
 
 
 
 
 
 
 
 
 
 
52
 
53
  # Prepare input: must be a 4D tensor (B, C=2, H, W)
54
  inputs = torch.randn(1, 2, 64, 64) # Replace with actual wind field data
 
44
  ## 🚀 How to Use
45
 
46
  ```python
47
+ from super_image import EdsrModel, EdsrConfig
48
+ from huggingface_hub import hf_hub_download
49
  import torch
50
 
51
+ # load config
52
+ config, _ = EdsrConfig.from_pretrained("lschmidt/edsr-dsc")
53
+
54
+ # load & modify model
55
+ model = EdsrModel(config)
56
+ del model.sub_mean
57
+ del model.add_mean
58
+
59
+ # load pre-trained weights
60
+ state_dict_path = hf_hub_download(repo_id="lschmidt/edsr-dsc", filename="pytorch_model_4x.pt")
61
+ state_dict = torch.load(state_dict_path, map_location="cpu")
62
+ model.load_state_dict(state_dict, strict=False)
63
 
64
  # Prepare input: must be a 4D tensor (B, C=2, H, W)
65
  inputs = torch.randn(1, 2, 64, 64) # Replace with actual wind field data