| import torch | |
| from model import * # import extracted model classes | |
def load_model(weights_path="model.pt", device=None):
    """Build an ``IcosahedralRRF`` and load saved weights into it for inference.

    ``IcosahedralRRF`` is presumably provided by ``from model import *``.

    Args:
        weights_path: File passed to ``torch.load``. Its contents are fed
            straight to ``load_state_dict``, so it is expected to hold a
            state dict.
        device: Target device. When ``None``, ``"cuda"`` is used if
            ``torch.cuda.is_available()``, otherwise ``"cpu"``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Errors raised by ``torch.load`` (e.g. a missing file) or by
    ``load_state_dict`` (e.g. mismatched keys) propagate to the caller.
    """
    target = device if device is not None else (
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    net = IcosahedralRRF()
    # map_location places the loaded tensors on `target` directly, so a
    # GPU-saved checkpoint can still be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file unless weights_only=True is
    # passed; only load checkpoints from trusted sources.
    checkpoint = torch.load(weights_path, map_location=target)
    net.load_state_dict(checkpoint)
    # Module.to() and Module.eval() both return the module itself.
    return net.to(target).eval()
if __name__ == "__main__":
    # Choose the device once and use it for both the model and the input.
    # Otherwise, on a CUDA machine, the weights end up on the GPU while the
    # input stays on the CPU, and the forward pass fails.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(device=device)
    # Example dummy input - adjust to your model's expected input.
    # NOTE(review): (1, 3, 224, 224) is a placeholder image-like shape; confirm
    # the input layout IcosahedralRRF actually expects.
    x = torch.randn(1, 3, 224, 224, device=device)
    # Inference only: disable autograd graph construction.
    with torch.no_grad():
        y = model(x)
    print("Output shape:", y.shape)