import torch


def load_train_model(args, dev):
    """Restore the GATr particle-flow model from ``args.load_model_weights``.

    Args:
        args: Run configuration. ``args.load_model_weights`` is the checkpoint
            path; ``args`` itself is also forwarded to the model constructor.
        dev: ``map_location`` passed to the checkpoint loader (e.g. ``"cpu"``
            or a ``torch.device``).

    Returns:
        The restored ``ExampleWrapper`` instance. Weights are loaded with
        ``strict=False``, so missing/unexpected keys are not reported as errors.
    """
    # Deferred import: the model package is only loaded when this is called.
    from src.models.Gatr_pf_e_noise import ExampleWrapper as GravnetModel

    # NOTE(review): the constructor always receives dev=0 independently of
    # ``dev``; presumably a device index used inside the model — confirm.
    return GravnetModel.load_from_checkpoint(
        args.load_model_weights,
        args=args,
        dev=0,
        map_location=dev,
        strict=False,
    )


def load_test_model(args, dev):
    """Build the GATr particle-flow model for evaluation and switch it to eval mode.

    Two modes, selected by ``args.correction``:

    * falsy: the whole model is restored from ``args.load_model_weights``.
    * truthy: a fresh model is built from ``args`` and loaded (non-strictly)
      with the ``state_dict`` stored in ``args.load_model_weights``; its
      ``gatr``, ``ScaledGooeyBatchNorm2_1``, ``clustering`` and ``beta``
      attributes are then replaced by those of the model restored from
      ``args.load_model_weights_clustering``.

    Args:
        args: Run configuration providing ``load_model_weights``,
            ``correction`` and, in correction mode,
            ``load_model_weights_clustering``. Also forwarded to the model.
        dev: ``map_location`` used for every checkpoint load.

    Returns:
        The assembled model, after ``model.eval()``.

    Raises:
        ValueError: if ``args.load_model_weights`` is None, or if correction
            mode is requested while ``args.load_model_weights_clustering``
            is None.
    """
    # Validate up front: previously a missing path left ``model`` unassigned
    # and surfaced as an UnboundLocalError at ``model.eval()``.
    if args.load_model_weights is None:
        raise ValueError("load_test_model requires args.load_model_weights to be set")
    if args.correction and args.load_model_weights_clustering is None:
        raise ValueError(
            "correction mode requires args.load_model_weights_clustering to be set"
        )

    # Deferred import: the model package is only loaded when this is called.
    from src.models.Gatr_pf_e_noise import ExampleWrapper as GravnetModel

    if not args.correction:
        model = GravnetModel.load_from_checkpoint(
            args.load_model_weights,
            args=args,
            dev=0,
            map_location=dev,
            strict=False,
        )
    else:
        # torch.load unpickles the file: only load trusted checkpoints.
        ckpt = torch.load(args.load_model_weights, map_location=dev)
        model = GravnetModel(args=args, dev=0)
        # strict=False: keys absent from / extra to the model are ignored.
        model.load_state_dict(ckpt["state_dict"], strict=False)

        # Use the caller's ``dev`` (previously hard-coded to cuda:0, which
        # failed on CPU-only hosts and ignored the requested device).
        clustering_model = GravnetModel.load_from_checkpoint(
            args.load_model_weights_clustering,
            args=args,
            dev=0,
            strict=False,
            map_location=dev,
        )
        # Take the clustering-related components from the second checkpoint.
        model.gatr = clustering_model.gatr
        model.ScaledGooeyBatchNorm2_1 = clustering_model.ScaledGooeyBatchNorm2_1
        model.clustering = clustering_model.clustering
        model.beta = clustering_model.beta

    model.eval()
    return model