rcan-dsc / rcan_model.py
lschmidt's picture
Create rcan_model.py
513c07c verified
raw
history blame contribute delete
979 Bytes
import torch
import torch.nn as nn
from super_image import RcanModel, RcanConfig
class CustomRcan(RcanModel):
    """RCAN whose forward pass omits the sub_mean / add_mean steps.

    Intended for inputs that are physical fields rather than images
    (e.g. the u and v wind components), for which image-style mean
    normalization is not applicable.
    """

    def forward(self, x):
        """Run head -> body -> residual add -> tail on ``x`` and return the result.

        The ``sub_mean`` and ``add_mean`` stages are deliberately not applied.
        """
        features = self.head(x)
        residual = self.body(features)
        # Long skip connection around the body; kept in-place on the body
        # output, which is a fresh tensor owned by this call.
        residual += features
        return self.tail(residual)
def load_rcan(pretrained_repo="lschmidt/rcan-dsc", config_file="config.json", weight_file="pytorch_model_4x.pt"):
    """Build a ``CustomRcan`` and load weights for it from the Hugging Face Hub.

    Args:
        pretrained_repo: Hub repository id that holds the config and weights.
        config_file: Name of the configuration file inside the repository.
        weight_file: Name of the checkpoint file inside the repository.

    Returns:
        The ``CustomRcan`` instance after the downloaded state dict has been
        applied non-strictly. The checkpoint tensors are deserialized onto
        the CPU; the model's train/eval mode is not changed here.

    Warns:
        UserWarning: if the checkpoint and the model disagree on parameter
            names. Loading stays best-effort (``strict=False``), but the
            mismatch is reported instead of being silently ignored.
    """
    import warnings

    from huggingface_hub import hf_hub_download

    # from_pretrained is unpacked as a 2-tuple; the second element is unused.
    config, _ = RcanConfig.from_pretrained(pretrained_repo, config_filename=config_file)
    model = CustomRcan(config)

    state_dict_path = hf_hub_download(repo_id=pretrained_repo, filename=weight_file)
    # NOTE(review): torch.load unpickles the downloaded file, which can run
    # arbitrary code for an untrusted repository. Consider weights_only=True
    # once the minimum supported torch version and checkpoint format allow it.
    state_dict = torch.load(state_dict_path, map_location="cpu")

    # strict=False tolerates key mismatches; surface them so a checkpoint
    # that matches nothing does not pass unnoticed with random weights.
    load_result = model.load_state_dict(state_dict, strict=False)
    missing_keys = list(load_result.missing_keys)
    unexpected_keys = list(load_result.unexpected_keys)
    if missing_keys or unexpected_keys:
        warnings.warn(
            f"Non-strict load of '{weight_file}' from '{pretrained_repo}': "
            f"{len(missing_keys)} missing key(s) {missing_keys}, "
            f"{len(unexpected_keys)} unexpected key(s) {unexpected_keys}.",
            stacklevel=2,
        )
    return model