Spaces:
Runtime error
Runtime error
| import argparse | |
| import numpy as np | |
| import pickle | |
| import os | |
| import yaml | |
| import torch | |
| import torch.nn as nn | |
| from models import UNetEncoder, Decoder | |
def load_training_data(
    path: str,
    standardize_weather: bool = False,
    standardize_so4: bool = False,
    log_so4: bool = False,
    remove_zeros: bool = True,
    return_pp_data: bool = False,
    year_averages: bool = False,
):
    """Load the pickled covariate / SO4 rasters and optionally preprocess them.

    The pickle at ``path`` must hold a dict with the keys ``"covars_rast"``,
    ``"covars_names"``, ``"so4_rast"``, ``"so4_mask"`` and, when
    ``return_pp_data`` is set, ``"pp_locs"``.

    Args:
        path: Location of the pickle file. Unpickling can execute arbitrary
            code, so only pass files from a trusted source.
        standardize_weather: Center and scale each covariate channel with its
            mean/std over axes (0, 2, 3). A channel whose std is exactly 0 is
            only centered (its scale is treated as 1) instead of becoming NaN.
        standardize_so4: Center and scale ``Y`` with the mean/std of the
            entries selected by the final mask ``M``. If that std is not
            positive, ``Y`` is only centered.
        log_so4: Replace ``Y`` with ``log(M * Y + 1)``.
        remove_zeros: If True, ``M`` keeps only pixels that are masked-in and
            have ``Y > 0`` at every index of ``Y``'s first axis. Otherwise the
            static mask is repeated once per index of that axis.
        return_pp_data: Also return ``data["pp_locs"]``.
        year_averages: Append along axis 1 a trailing 12-step average of the
            covariates, with matching ``".yavg"`` names. Step ``t >= 12``
            uses steps ``t-12 .. t-1``; steps ``t < 12`` use the mean of the
            first 12 steps.

    Returns:
        ``(C, names, Y, M)``, or ``(C, names, Y, M, pp_locs)`` when
        ``return_pp_data`` is True. Every ``"."`` in ``names`` is replaced
        by ``"_"``.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(path, "rb") as f:
        data = pickle.load(f)
    C = data["covars_rast"]  # [:, weather_cols]
    names = data["covars_names"]

    if standardize_weather:
        C -= C.mean((0, 2, 3), keepdims=True)
        scale = C.std((0, 2, 3), keepdims=True)
        # A constant channel has std 0; dividing by it would produce NaNs.
        scale[scale == 0] = 1.0
        C /= scale

    if year_averages:
        n_steps = C.shape[0]
        Cyearly_average = np.zeros_like(C)
        if n_steps > 0:
            # The first 12 steps have no full trailing window, so they all
            # share the mean of the first 12 steps (computed once).
            Cyearly_average[:12] = np.mean(C[:12], 0)
        for t in range(12, n_steps):
            # Mean over the 12 steps strictly before t.
            Cyearly_average[t] = np.mean(C[(t - 12) : t], 0)
        C = np.concatenate([C, Cyearly_average], 1)
        names = names + [x + ".yavg" for x in names]

    # Replace every "." with "_" (this also turns ".yavg" into "_yavg").
    # NOTE(review): applied to all names, not only when year_averages is set;
    # confirm this matches what callers expect.
    names = [x.replace(".", "_") for x in names]

    Y = data["so4_rast"]
    M = data["so4_mask"]
    # Hard-coded grid regions that are always excluded.
    M[92:, 185:] = 0.0  # annoying weird corner
    M[80:, :60] = 0.0  # annoying weird corner
    if remove_zeros:
        M = (Y > 0) * M
        # Keep a pixel only if it is valid at every index of the first axis.
        M = M * np.prod(M, 0)
    else:
        # Repeat the static mask once per index of Y's first axis.
        M = np.stack([M] * Y.shape[0])

    if log_so4:
        # The +1 keeps masked-out (zeroed) entries at log(1) = 0.
        Y = np.log(M * Y + 1.0)

    if standardize_so4:
        ix = np.where(M)
        Y -= Y[ix].mean()
        scale = Y[ix].std()
        # Avoid dividing by zero when all selected values are equal.
        if scale > 0:
            Y /= scale

    if return_pp_data:
        return C, names, Y, M, data["pp_locs"]
    return C, names, Y, M
def radius_from_dir(s: str, prefix: str):
    """Parse the integer radius encoded in a run directory's name.

    The last path component is expected to look like ``<prefix><int>_...``
    (e.g. ``"runs/h3_model"`` with ``prefix="h"`` gives ``3``). Trailing path
    separators are ignored, and a leading ``prefix`` on the first
    ``"_"``-separated token is stripped before conversion.

    Raises:
        ValueError: if the remaining token is not an integer.
    """
    # normpath drops trailing separators so basename never returns "".
    base = os.path.basename(os.path.normpath(s))
    token = base.split("_")[0]
    if token.startswith(prefix):
        token = token[len(prefix):]
    try:
        return int(token)
    except ValueError as err:
        raise ValueError(
            f"cannot parse a radius with prefix {prefix!r} from {s!r}"
        ) from err
def load_models(dirs: dict, prefix="h", nd=5):
    """Rebuild and load the trained encoder/decoder of every run in ``dirs``.

    For each run directory, the hyper-parameters are read from ``args.yaml``
    and an ``nn.ModuleDict({"enc": ..., "dec": ...})`` is constructed. The
    module is switched to eval mode with all parameters frozen, and its
    weights are restored from ``model.pt``, mapped onto the CPU.

    Args:
        dirs: Mapping whose values are run directories; its keys are not used.
        prefix: Prefix preceding the radius in each directory's name
            (forwarded to ``radius_from_dir``).
        nd: Size argument forwarded to ``UNetEncoder``/``Decoder``
            (presumably the number of input channels — confirm).

    Returns:
        Dict keyed by run directory (not by the keys of ``dirs``). Each value
        is a dict with entries ``mod``, ``args``, ``radius``, ``nbrs_av`` and
        ``local``.
    """
    D = {}
    # The keys of ``dirs`` are ignored; results are keyed by directory path.
    for datadir in dirs.values():
        radius = radius_from_dir(datadir, prefix)

        # Rebuild the training-time argument namespace from args.yaml.
        # NOTE: yaml.FullLoader is not the safe loader; only read trusted files.
        args = argparse.Namespace()
        with open(os.path.join(datadir, "args.yaml"), "r") as f:
            for k, v in yaml.load(f, Loader=yaml.FullLoader).items():
                setattr(args, k, v)
                # "nbrs_av" and "av_nbrs" are kept in sync as aliases.
                if k == "nbrs_av":
                    setattr(args, "av_nbrs", v)
                elif k == "av_nbrs":
                    setattr(args, "nbrs_av", v)

        # Fall back to "frn" when args.yaml has no bn_type entry.
        bn_type = getattr(args, "bn_type", "frn")
        mkw = dict(
            n_hidden=args.nhidden,
            depth=args.depth,
            num_res=args.nres,
            ksize=args.ksize,
            groups=args.groups,
            batchnorm=True,
            batchnorm_type=bn_type,
        )
        dkw = dict(batchnorm=True, offset=True, batchnorm_type=bn_type)

        # A UNet encoder is built only when args.local is false and
        # args.nbrs_av == 0; otherwise the encoder is an identity and the
        # decoder is constructed directly on ``nd``.
        if not args.local and args.nbrs_av == 0:
            enc = UNetEncoder(nd, args.nhidden, **mkw)
            dec = Decoder(args.nhidden, nd, args.nhidden, **dkw)
        else:
            enc = nn.Identity()
            dec = Decoder(nd, nd, args.nhidden, **dkw)
        mod = nn.ModuleDict({"enc": enc, "dec": dec})

        # Inference only: eval mode and frozen parameters.
        mod.eval()
        for p in mod.parameters():
            p.requires_grad = False

        # Weights are always mapped to the CPU. torch.load unpickles the file,
        # so only load checkpoints from trusted sources.
        weights_path = os.path.join(datadir, "model.pt")
        state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
        mod.load_state_dict(state_dict)

        D[datadir] = dict(
            mod=mod,
            args=args,
            radius=radius,
            nbrs_av=args.nbrs_av,
            local=args.local,
        )
    return D