# example.py
"""Smoke-test a pre-trained SRResNet checkpoint on a random input tensor.

Builds the network, loads weights from ``best_netG.pth``, runs a single
forward pass on random data and prints the input and output shapes.
"""
import torch

from model import SRResNet  # Import your model class

# Constructor arguments; these presumably must match the configuration the
# checkpoint was trained with, otherwise load_state_dict will reject the
# weights on a key/shape mismatch.
IN_CHANNELS = 12
OUT_CHANNELS = 72
UPSCALE = 1
CHECKPOINT_PATH = "best_netG.pth"


def main(checkpoint_path=CHECKPOINT_PATH, height=128, width=128):
    """Load the checkpoint, run one forward pass and print tensor shapes.

    Args:
        checkpoint_path: Path to a saved state dict for ``SRResNet``.
        height: Height of the random test input.
        width: Width of the random test input.

    Returns:
        The model output for a random input of shape
        ``(1, IN_CHANNELS, height, width)``.
    """
    model = SRResNet(
        in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, upscale=UPSCALE
    )
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # machine; weights_only=True restricts unpickling to tensors/containers
    # (requires torch >= 1.13) instead of executing arbitrary pickled objects.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    # Switch to inference behaviour: layers such as BatchNorm/Dropout (if the
    # model has any) would otherwise act as in training.
    model.eval()

    # Random test input: batch of 1, IN_CHANNELS (=12) channels, height x width.
    input_tensor = torch.rand(1, IN_CHANNELS, height, width)

    # No gradients are needed for inference; skip building the autograd graph.
    with torch.no_grad():
        output_tensor = model(input_tensor)

    print("Input shape:", input_tensor.shape)
    print("Output shape:", output_tensor.shape)
    return output_tensor


if __name__ == "__main__":
    main()