# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import glob
import json
import logging
from urllib.parse import urlparse

import numpy as np
import torch
import xarray

import physicsnemo  # noqa: F401 for docs
from physicsnemo.models.dlwp import DLWP
from physicsnemo.models.graphcast.graph_cast_net import GraphCastNet
from physicsnemo.utils import filesystem
from physicsnemo.utils.zenith_angle import cos_zenith_angle

logger = logging.getLogger(__name__)


# class _DummyModule(torch.nn.Module):
#     """Hack to handle that checkpoint parameter names begin with "module."
#     """
#     def __init__(self, model):
#         super().__init__()
#         self.module = model


class _CosZenWrapper(torch.nn.Module):
    """Wrap a model so the cosine of the solar zenith angle is appended to its input.

    Parameters
    ----------
    model : torch.nn.Module
        Wrapped model; it is called with the augmented tensor only.
    lon : array-like
        Longitude coordinates passed to ``np.meshgrid``.
    lat : array-like
        Latitude coordinates passed to ``np.meshgrid``.
    """

    def __init__(self, model, lon, lat):
        super().__init__()
        self.model = model
        self.lon = lon
        self.lat = lat

    def forward(self, x, time):
        """Append the cosine zenith angle at ``time`` to ``x`` and run the model.

        Parameters
        ----------
        x : torch.Tensor
            Model input; the zenith field is concatenated along ``dim=1``.
        time
            Time forwarded to ``cos_zenith_angle``.

        Returns
        -------
        Whatever the wrapped model returns for the augmented input.
        """
        lon_grid, lat_grid = np.meshgrid(self.lon, self.lat)
        cosz = cos_zenith_angle(time, lon_grid, lat_grid)
        cosz = cosz.astype(np.float32)
        z = torch.from_numpy(cosz).to(device=x.device)
        # NOTE(review): both tensors are broadcast to a common shape before the
        # concatenation, so z takes on x's full shape -- confirm this is intended.
        x, z = torch.broadcast_tensors(x, z)
        x = torch.cat([x, z], dim=1)
        return self.model(x)


# def sfno(package: filesystem.Package, pretrained: bool = True) -> torch.nn.Module:
#     """Load SFNO model from checkpoints trained with era5_wind"""
#     path = package.get("config.json")
#     params = ParamsBase.from_json(path)
#     model = sfnonet.SphericalFourierNeuralOperatorNet(params)
#     logger.info(str(params.to_dict()))
#     if pretrained:
#         weights = package.get("weights.tar")
#         checkpoint = torch.load(weights)
#         load_me = _DummyModule(model)
#         state = checkpoint["model_state"]
#         state = {"module.device_buffer": model.device_buffer, **state}
#         load_me.load_state_dict(state)
#     if params.add_zenith:
#         nlat = params.img_shape_x
#         nlon = params.img_shape_y
#         lat = 90 - np.arange(nlat) * 0.25
#         lon = np.arange(nlon) * 0.25
#         model = _CosZenWrapper(model, lon, lat)
#     return model


class _GraphCastWrapper(torch.nn.Module):
    """Cast the input to a fixed dtype before calling the wrapped model.

    Parameters
    ----------
    model : torch.nn.Module
        Wrapped model.
    dtype : torch.dtype
        Dtype the input tensor is converted to in ``forward``.
    """

    def __init__(self, model, dtype):
        super().__init__()
        self.model = model
        self.dtype = dtype

    def forward(self, x):
        """Convert ``x`` to ``self.dtype`` and return the wrapped model's output."""
        x = x.to(self.dtype)
        y = self.model(x)
        return y


def graphcast_34ch(
    package: filesystem.Package,
    pretrained: bool = True,
    device: "torch.device | str | None" = None,
) -> torch.nn.Module:
    """Load Graphcast 34 channel model from a checkpoint

    Parameters
    ----------
    package : filesystem.Package
        Package providing ``icospheres.json``, the ``static`` directory and,
        when ``pretrained`` is True, ``weights.tar``.
    pretrained : bool, optional
        If True, load the ``model_state_dict`` entry of ``weights.tar`` into
        the model, by default True
    device : torch.device | str | None, optional
        Device the model is moved to. When None (default), ``"cuda"`` is used
        if available and ``"cpu"`` otherwise, so CUDA hosts behave as before.

    Returns
    -------
    torch.nn.Module
        The bfloat16 GraphCast model, in eval mode, wrapped so that inputs
        are cast to bfloat16.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    num_channels = 34

    icospheres_path = package.get("icospheres.json")
    static_data_path = package.get("static", recursive=True)

    # instantiate the model, set dtype
    base_model = (
        GraphCastNet(
            meshgraph_path=icospheres_path,
            static_dataset_path=static_data_path,
            input_dim_grid_nodes=num_channels,
            input_dim_mesh_nodes=3,
            input_dim_edges=4,
            output_dim_grid_nodes=num_channels,
            processor_layers=16,
            hidden_dim=512,
            do_concat_trick=True,
        )
        .to(dtype=torch.bfloat16)
        .to(device)
    )

    # set model to inference mode
    base_model.eval()

    model = _GraphCastWrapper(base_model, torch.bfloat16)

    if pretrained:
        path = package.get("weights.tar")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # map_location lets checkpoints saved on another device load here.
        checkpoint = torch.load(path, map_location=device)
        weights = checkpoint["model_state_dict"]
        weights = _fix_state_dict_keys(weights, add_module=False)
        model.model.load_state_dict(weights, strict=True)

    return model


class _DLWPWrapper(torch.nn.Module):
    """Wrap a DLWP core model with lat-lon <-> cubed-sphere remapping.

    ``forward`` remaps the lat-lon input onto the cubed sphere, appends the
    solar forcing, land-sea mask and normalised topographic height channels,
    runs the core model and remaps its output back to the lat-lon grid.

    Parameters
    ----------
    model : torch.nn.Module
        Core model operating on cubed-sphere data.
    lsm : array-like
        Land-sea mask, appended as an input channel.
    longrid, latgrid : array-like
        Longitude / latitude fields passed to ``cos_zenith_angle``.
    topographic_height : array-like
        Topographic height field, normalised and appended as an input channel.
    ll_to_cs_mapfile_path : str
        Path of the lat-lon -> cubed-sphere map file.
    cs_to_ll_mapfile_path : str
        Path of the cubed-sphere -> lat-lon map file.
    """

    def __init__(
        self,
        model,
        lsm,
        longrid,
        latgrid,
        topographic_height,
        ll_to_cs_mapfile_path,
        cs_to_ll_mapfile_path,
    ):
        super().__init__()
        self.model = model
        self.lsm = lsm
        self.longrid = longrid
        self.latgrid = latgrid
        self.topographic_height = topographic_height

        # load map weights
        # Note: these map files are created using TempestRemap library
        # https://github.com/ClimateGlobalChange/tempestremap
        # To generate the maps, the below sequence of commands can be
        # executed once TempestRemap is installed.
        # GenerateRLLMesh --lat 721 --lon 1440 --file out_latlon.g --lat_begin 90 --lat_end -90 --out_format Netcdf4
        # GenerateCSMesh --res --file out_cubedsphere.g --out_format Netcdf4
        # GenerateOverlapMesh --a out_latlon.g --b out_cubedsphere.g --out overlap_latlon_cubedsphere.g --out_format Netcdf4
        # GenerateOfflineMap --in_mesh out_latlon.g --out_mesh out_cubedsphere.g --ov_mesh overlap_latlon_cubedsphere.g --in_np 1 --in_type FV --out_type FV --out_map map_LL_CS.nc --out_format Netcdf4
        # GenerateOverlapMesh --a out_cubedsphere.g --b out_latlon.g --out overlap_cubedsphere_latlon.g --out_format Netcdf4
        # GenerateOfflineMap --in_mesh out_cubedsphere.g --out_mesh out_latlon.g --ov_mesh overlap_cubedsphere_latlon.g --in_np 1 --in_type FV --out_type FV --out_map map_CS_LL.nc --out_format Netcdf4
        #
        # load_dataset reads the data into memory and closes the file, so no
        # file handle stays open for the lifetime of the module.
        self.input_map_wts = xarray.load_dataset(ll_to_cs_mapfile_path)
        self.output_map_wts = xarray.load_dataset(cs_to_ll_mapfile_path)

    @staticmethod
    def _sparse_map(map_wts, dtype, device):
        """Build the sparse remapping matrix stored in a map dataset.

        ``row`` / ``col`` are shifted by one to obtain zero-based indices and
        ``S`` supplies the matrix values.
        """
        rows = map_wts.row.values - 1
        cols = map_wts.col.values - 1
        values = map_wts.S.values
        matrix = torch.sparse_coo_tensor(np.array((rows, cols)), values)
        return matrix.type(dtype).to(device)

    def prepare_input(self, input, time):
        """Remap the input to the cubed sphere and append the forcing channels.

        Parameters
        ----------
        input : torch.Tensor
            Tensor whose first three dimensions are read as (N, T, C); the
            remaining dimensions are flattened for the remapping.
        time : datetime.datetime
            Time of the last input step; step ``t`` is taken to be
            ``6 * (T - 1 - t)`` hours earlier.

        Returns
        -------
        torch.Tensor
            Per-step channels (each followed by its solar-forcing channel)
            concatenated along ``dim=1``, followed by the land-sea mask and
            the normalised topographic height.
        """
        device = input.device
        dtype = input.dtype

        M = self._sparse_map(self.input_map_wts, dtype, device)

        N, T, C = input.shape[0], input.shape[1], input.shape[2]
        input = (M @ input.reshape(N * T * C, -1).T).T
        # the map's row count is split into 6 faces of S x S cells below
        S = int((M.shape[0] / 6) ** 0.5)
        input = input.reshape(N, T, C, 6, S, S)

        repeat_vals = (N, -1, -1, -1, -1)  # repeat along batch dimension
        first_step_time = time - datetime.timedelta(hours=6 * (T - 1))

        step_inputs = []
        # unbind(dim=1) yields one (N, C, 6, S, S) tensor per time step
        for step, step_tensor in enumerate(input.unbind(dim=1)):
            tisr = np.maximum(
                cos_zenith_angle(
                    first_step_time + datetime.timedelta(hours=6 * step),
                    self.longrid,
                    self.latgrid,
                ),
                0,
            ) - (1 / np.pi)  # subtract mean value
            tisr = (
                torch.tensor(tisr, dtype=dtype)
                .to(device)
                .unsqueeze(dim=0)
                .unsqueeze(dim=0)
            )  # add channel and batch size dimension
            tisr = tisr.expand(*repeat_vals)  # TODO - find better way to batch TISR
            # concat along channel dim
            step_inputs.append(torch.cat((step_tensor, tisr), dim=1))

        # concat the time dimension into channels
        input_model = torch.cat(step_inputs, dim=1)

        lsm_tensor = torch.tensor(self.lsm, dtype=dtype).to(device).unsqueeze(dim=0)
        lsm_tensor = lsm_tensor.expand(*repeat_vals)

        # NOTE(review): 3.724e03 / 8.349e03 look like the mean / std used to
        # normalise the topographic height -- confirm against the training setup.
        topographic_height_tensor = (
            torch.tensor((self.topographic_height - 3.724e03) / 8.349e03, dtype=dtype)
            .to(device)
            .unsqueeze(dim=0)
        )
        topographic_height_tensor = topographic_height_tensor.expand(*repeat_vals)

        input_model = torch.cat(
            (input_model, lsm_tensor, topographic_height_tensor), dim=1
        )
        return input_model

    def prepare_output(self, output):
        """Split the model output into two time steps and remap it to lat-lon.

        Parameters
        ----------
        output : torch.Tensor
            Core model output; ``dim=1`` is split into two equal halves that
            become the time dimension.

        Returns
        -------
        torch.Tensor
            Tensor of shape (N, 2, C, 721, 1440).
        """
        device = output.device
        dtype = output.dtype

        output = torch.split(output, output.shape[1] // 2, dim=1)
        output = torch.stack(output, dim=1)  # add time dimension back in

        M = self._sparse_map(self.output_map_wts, dtype, device)

        N, T, C = output.shape[0], 2, output.shape[2]
        output = (M @ output.reshape(N * T * C, -1).T).T
        output = output.reshape(N, T, C, 721, 1440)
        return output

    def forward(self, x, time):
        """Run the core model on lat-lon input ``x`` valid at ``time``."""
        x = self.prepare_input(x, time)
        y = self.model(x)
        return self.prepare_output(y)


def _find_map_file(package, root_path, pattern, strip_separator):
    """Return the package-relative path of the first file matching ``pattern``.

    Parameters
    ----------
    package : filesystem.Package
        Package whose ``seperator`` attribute joins ``root_path`` and ``pattern``.
    root_path : str
        Directory that is searched with ``glob``.
    pattern : str
        Glob pattern of the file name.
    strip_separator : bool
        If True, leading separator characters are removed from the result.

    Raises
    ------
    FileNotFoundError
        If no file matches the pattern.
    """
    matches = glob.glob(root_path + package.seperator + pattern)
    if not matches:
        raise FileNotFoundError(
            f"No map file matching '{pattern}' found in '{root_path}'"
        )
    relative_path = matches[0][len(root_path) :]  # take the first match
    if strip_separator:
        relative_path = relative_path.lstrip(package.seperator)
    return relative_path


def dlwp(package, pretrained=True):
    """Load the DLWP model together with its static fields and map files.

    Parameters
    ----------
    package : filesystem.Package
        Package providing the static NetCDF fields, the ``map_LL*_CS*.nc`` and
        ``map_CS*_LL*.nc`` map files, ``config.json`` and, when ``pretrained``
        is True, ``weights.pt``.
    pretrained : bool, optional
        If True, load ``weights.pt`` into the core model, by default True

    Returns
    -------
    torch.nn.Module
        The wrapped DLWP model in eval mode.

    Raises
    ------
    FileNotFoundError
        If one of the two map files cannot be found under the package root.
    """
    # load static datasets; .values reads the data eagerly, so the files can
    # be closed as soon as the arrays are extracted
    with xarray.open_dataset(package.get("land_sea_mask_rs_cs.nc")) as lsm_ds:
        lsm = lsm_ds["lsm"].values
    with xarray.open_dataset(package.get("geopotential_rs_cs.nc")) as topo_ds:
        topographic_height = topo_ds["z"].values
    with xarray.open_dataset(package.get("latlon_grid_field_rs_cs.nc")) as grid_ds:
        latgrid, longrid = grid_ds["latgrid"].values, grid_ds["longrid"].values

    # load maps
    parsed_uri = urlparse(package.root)
    is_file_uri = parsed_uri.scheme == "file"
    root_path = parsed_uri.path if is_file_uri else package.root

    # Resolve both names before fetching anything, so a missing map file is
    # reported up front instead of as an unbound local variable.
    ll_to_cs_relative_path = _find_map_file(
        package, root_path, "map_LL*_CS*.nc", is_file_uri
    )
    cs_to_ll_relative_path = _find_map_file(
        package, root_path, "map_CS*_LL*.nc", is_file_uri
    )
    ll_to_cs_mapfile_path = package.get(ll_to_cs_relative_path)
    cs_to_ll_mapfile_path = package.get(cs_to_ll_relative_path)

    with open(package.get("config.json")) as json_file:
        config = json.load(json_file)
        core_model = DLWP(
            nr_input_channels=config["nr_input_channels"],
            nr_output_channels=config["nr_output_channels"],
        )

    if pretrained:
        weights_path = package.get("weights.pt")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # The core model is on the CPU at this point, so load the weights there
        # regardless of the device they were saved from.
        weights = torch.load(weights_path, map_location="cpu")
        fixed_weights = _fix_state_dict_keys(weights, add_module=False)
        core_model.load_state_dict(fixed_weights)

    model = _DLWPWrapper(
        core_model,
        lsm,
        longrid,
        latgrid,
        topographic_height,
        ll_to_cs_mapfile_path,
        cs_to_ll_mapfile_path,
    )

    model.eval()

    return model


def _fix_state_dict_keys(state_dict, add_module=False):
    """Add or remove the leading 'module.' prefix of state_dict keys

    Only a prefix at the very start of a key is removed; keys that merely
    contain the text (e.g. ``"encoder.submodule.weight"``) are left intact.

    Parameters
    ----------
    state_dict : Dict
        Model state_dict
    add_module : bool, optional
        If True, will add 'module.' to keys, by default False

    Returns
    -------
    Dict
        Model state_dict with fixed keys
    """
    prefix = "module."
    if add_module:
        return {prefix + key: value for key, value in state_dict.items()}
    return {
        (key[len(prefix) :] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }