"""Load IcosahedralRRF weights from disk; run a dummy forward pass as a script."""

import torch

from model import *  # noqa: F403 -- expected to provide IcosahedralRRF


def _default_device():
    """Return "cuda" when a CUDA device is available, otherwise "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def load_model(weights_path="model.pt", device=None):
    """Build an ``IcosahedralRRF`` and load saved weights into it.

    Args:
        weights_path: Path to a file written by ``torch.save``. Its contents
            are passed directly to ``load_state_dict``, so it is assumed to be
            a plain state dict rather than a pickled whole model -- confirm
            for your checkpoints.
        device: Target device (any value accepted by ``torch.load``'s
            ``map_location`` and by ``Module.to``). Defaults to CUDA when
            available, else CPU.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Errors raised by ``torch.load`` or ``load_state_dict`` (e.g. a missing
    file or mismatched keys) propagate unchanged.
    """
    if device is None:
        device = _default_device()
    model = IcosahedralRRF()  # noqa: F405
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    # SECURITY: torch.load unpickles the file, which can execute arbitrary code.
    # Only load trusted checkpoints; on torch >= 1.13 consider weights_only=True.
    state = torch.load(weights_path, map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model


def main():
    """Load the default checkpoint and print the output shape for a dummy batch."""
    device = _default_device()
    model = load_model(device=device)
    # Example dummy input - adjust to your model's expected input.
    # It must live on the same device as the model: feeding a CPU tensor to a
    # CUDA model raises a device-mismatch error.
    x = torch.randn(1, 3, 224, 224, device=device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        y = model(x)
    print("Output shape:", y.shape)


if __name__ == "__main__":
    main()