|
|
import torch |
|
|
import torch.nn as nn |
|
|
from super_image import RcanModel, RcanConfig |
|
|
|
|
|
class CustomRcan(RcanModel):
    """RCAN model that bypasses the ``sub_mean`` / ``add_mean`` normalization.

    The input goes straight into the convolutional head, and the tail output is
    returned as-is, with no mean shift applied on either side. This suits
    physical fields such as wind components (u, v), where image-style mean
    normalization does not apply.

    NOTE(review): the ``sub_mean`` / ``add_mean`` submodules still exist on the
    instance (they come from ``RcanModel``); they are simply never called here.
    """

    def forward(self, x):
        """Run head -> residual body (with long skip connection) -> tail."""
        shallow = self.head(x)
        # Long skip connection: add the shallow features onto the body output (in place).
        deep = self.body(shallow)
        deep += shallow
        return self.tail(deep)
|
|
|
|
|
def load_rcan(
    pretrained_repo="lschmidt/rcan-dsc",
    config_file="config.json",
    weight_file="pytorch_model_4x.pt",
    strict=False,
):
    """Build a ``CustomRcan`` and load its weights from the Hugging Face Hub.

    Args:
        pretrained_repo: Hub repository id that holds both the config file and
            the checkpoint.
        config_file: Name of the config file inside ``pretrained_repo``.
        weight_file: Name of the checkpoint inside ``pretrained_repo``. It is
            loaded with ``torch.load`` and passed to ``load_state_dict``, so it
            is expected to contain a state dict.
        strict: Forwarded to ``model.load_state_dict``. With the default
            ``False``, key mismatches are tolerated but reported through
            ``warnings.warn`` instead of being silently ignored. With ``True``,
            ``load_state_dict`` raises on any mismatch.

    Returns:
        The ``CustomRcan`` instance with the checkpoint weights loaded (tensors
        are read with ``map_location="cpu"``).
    """
    import warnings

    from huggingface_hub import hf_hub_download

    # NOTE(review): the return value is unpacked into two names here; confirm the
    # installed super_image version's RcanConfig.from_pretrained returns a pair.
    config, _ = RcanConfig.from_pretrained(pretrained_repo, config_filename=config_file)
    model = CustomRcan(config)

    state_dict_path = hf_hub_download(repo_id=pretrained_repo, filename=weight_file)
    # SECURITY: torch.load unpickles the file, which can execute arbitrary code.
    # Only load checkpoints from repositories you trust (newer torch versions
    # also offer weights_only=True to restrict what can be unpickled).
    state_dict = torch.load(state_dict_path, map_location="cpu")

    incompatible = model.load_state_dict(state_dict, strict=strict)
    # With strict=False, load_state_dict skips non-matching keys without error.
    # A key-naming mismatch (e.g. a "module." prefix) would then leave the model
    # with its random initial weights, so make any mismatch visible.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"load_rcan: checkpoint '{weight_file}' does not match the model exactly; "
            f"missing keys: {list(incompatible.missing_keys)}, "
            f"unexpected keys: {list(incompatible.unexpected_keys)}",
            stacklevel=2,
        )

    return model
|
|
|