Spaces:
Sleeping
Sleeping
| import torch | |
def load_train_model(args, dev):
    """Restore the training model from the checkpoint named in ``args``.

    Args:
        args: Run configuration. ``args.load_model_weights`` is the checkpoint
            path; the whole object is also forwarded to the model wrapper.
        dev: Passed as ``map_location`` when the checkpoint is loaded.

    Returns:
        The ``ExampleWrapper`` instance built by ``load_from_checkpoint``.
    """
    # Imported lazily, as in the original, so the model package is only
    # pulled in when a model is actually requested.
    from src.models.Gatr_pf_e_noise import ExampleWrapper as GravnetModel

    checkpoint_path = args.load_model_weights
    # strict=False: checkpoint keys are not required to match the module exactly.
    return GravnetModel.load_from_checkpoint(
        checkpoint_path,
        args=args,
        dev=0,
        map_location=dev,
        strict=False,
    )
def load_test_model(args, dev):
    """Load the model used for testing and put it in eval mode.

    Two loading modes are selected by ``args.correction``:

    * falsy -- the whole wrapper is restored from ``args.load_model_weights``.
    * truthy -- a fresh wrapper is built, the ``"state_dict"`` entry of
      ``args.load_model_weights`` is loaded into it non-strictly, and then its
      ``gatr``, ``ScaledGooeyBatchNorm2_1``, ``clustering`` and ``beta``
      attributes are replaced by those of a second wrapper restored from
      ``args.load_model_weights_clustering``.

    Args:
        args: Run configuration providing ``load_model_weights``,
            ``correction`` and (in correction mode)
            ``load_model_weights_clustering``; also forwarded to the wrapper.
        dev: Passed as ``map_location`` for every checkpoint load.

    Returns:
        The assembled ``ExampleWrapper`` after ``eval()`` has been called.

    Raises:
        ValueError: If ``args.load_model_weights`` is ``None``.
    """
    # Without a checkpoint neither branch below can build a model; fail with a
    # clear message instead of an UnboundLocalError further down.
    if args.load_model_weights is None:
        raise ValueError(
            "args.load_model_weights must be set to load a test model"
        )

    from src.models.Gatr_pf_e_noise import ExampleWrapper as GravnetModel

    if not args.correction:
        model = GravnetModel.load_from_checkpoint(
            args.load_model_weights,
            args=args,
            dev=0,
            map_location=dev,
            strict=False,
        )
    else:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        ckpt = torch.load(args.load_model_weights, map_location=dev)
        state_dict = ckpt["state_dict"]
        model = GravnetModel(args=args, dev=0)
        model.load_state_dict(state_dict, strict=False)

        # Use the caller's device (not a hard-coded "cuda:0") so the grafted
        # submodules live on the same device as the rest of the model and the
        # function also works on CPU-only hosts.
        clustering_model = GravnetModel.load_from_checkpoint(
            args.load_model_weights_clustering,
            args=args,
            dev=0,
            strict=False,
            map_location=dev,
        )
        # Take these parts from the clustering checkpoint.
        model.gatr = clustering_model.gatr
        model.ScaledGooeyBatchNorm2_1 = clustering_model.ScaledGooeyBatchNorm2_1
        model.clustering = clustering_model.clustering
        model.beta = clustering_model.beta

    model.eval()
    return model