# MeteoGAN / example.py
# Last updated by manmeet3591 in commit c123368 ("Update example.py").
# example.py
import torch
from model import SRResNet # Import your model class
# Channel counts must match the architecture the checkpoint was trained with.
# The random test input below is built from IN_CHANNELS so the two cannot drift apart.
IN_CHANNELS = 12
OUT_CHANNELS = 72
CHECKPOINT_PATH = "best_netG.pth"

# Build the generator and load the pre-trained weights.
model = SRResNet(in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, upscale=1)
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; move the model to another device afterwards if desired.
# NOTE(review): torch.load unpickles the file. Only load checkpoints from a
# trusted source, or pass weights_only=True (PyTorch >= 1.13) if the file is a
# plain tensor state_dict.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(state_dict)
# Put the model in evaluation mode so layers such as BatchNorm/Dropout (if the
# architecture has any) use their inference behaviour rather than training behaviour.
model.eval()

# Random input for a smoke test: batch size 1, IN_CHANNELS channels, 128x128 grid.
input_tensor = torch.rand(1, IN_CHANNELS, 128, 128)

# Perform inference without recording an autograd graph (saves memory and time).
with torch.no_grad():
    output_tensor = model(input_tensor)

# Print input and output shapes.
print("Input shape:", input_tensor.shape)
print("Output shape:", output_tensor.shape)