# example.py
"""Smoke-test a pre-trained SRResNet checkpoint on a random input tensor.

Loads weights from ``best_netG.pth``, runs one forward pass on a random
(1, 12, 128, 128) tensor, and prints the input and output shapes.
"""
import torch

from model import SRResNet  # Import your model class

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained model.
# NOTE(review): in_channels/out_channels/upscale must match the values the
# checkpoint was trained with, otherwise load_state_dict will fail on
# mismatched parameter shapes.
model = SRResNet(in_channels=12, out_channels=72, upscale=1)
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only
# machine. torch.load unpickles the file, so only load trusted checkpoints.
state_dict = torch.load("best_netG.pth", map_location=device)
model.load_state_dict(state_dict)  # Load model weights
model.to(device)
# Put the model in inference mode in case it contains layers whose behavior
# differs between training and evaluation (e.g. dropout or batch norm).
model.eval()

# Create a random input tensor (e.g., for testing purposes):
# batch size 1, 12 channels, 128x128 spatial size.
input_tensor = torch.rand(1, 12, 128, 128, device=device)

# Perform inference without recording an autograd graph.
with torch.no_grad():
    output_tensor = model(input_tensor)

# Print input and output shapes
print("Input shape:", input_tensor.shape)
print("Output shape:", output_tensor.shape)