lschmidt commited on
Commit
513c07c
·
verified ·
1 Parent(s): cafd2d6

Create rcan_model.py

Browse files
Files changed (1) hide show
  1. rcan_model.py +28 -0
rcan_model.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from super_image import RcanModel, RcanConfig
4
+
5
+ class CustomRcan(RcanModel):
6
+ """
7
+ RCAN variant without sub_mean / add_mean normalization.
8
+ Useful for physical variables like wind components (u, v),
9
+ where image normalization is not applicable.
10
+ """
11
+ def forward(self, x):
12
+ # Skip sub_mean and add_mean
13
+ x = self.head(x)
14
+ res = self.body(x)
15
+ res += x
16
+ x = self.tail(res)
17
+ return x
18
+
19
+ def load_rcan(pretrained_repo="lschmidt/rcan-dsc", config_file="config.json", weight_file="pytorch_model_4x.pt"):
20
+ from huggingface_hub import hf_hub_download
21
+
22
+ config, _ = RcanConfig.from_pretrained(pretrained_repo, config_filename=config_file)
23
+ model = CustomRcan(config)
24
+
25
+ state_dict_path = hf_hub_download(repo_id=pretrained_repo, filename=weight_file)
26
+ state_dict = torch.load(state_dict_path, map_location="cpu")
27
+ model.load_state_dict(state_dict, strict=False)
28
+ return model