File size: 567 Bytes
4b697f3 60da226 4b697f3 c123368 4b697f3 c123368 4b697f3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | # example.py
import torch
from model import SRResNet # Import your model class
# Minimal smoke test: load trained SRResNet weights and run one forward pass
# on a random input, printing the input/output tensor shapes.

# Number of input channels, shared by the model constructor and the dummy
# input so the two cannot drift apart.
IN_CHANNELS = 12

# Load the pre-trained model
model = SRResNet(in_channels=IN_CHANNELS, out_channels=72, upscale=1)
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine (the model above is constructed on the CPU).
# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load("best_netG.pth", map_location="cpu")
model.load_state_dict(state_dict)  # Load model weights
# Put the model in inference mode. Without this, layers whose behavior depends
# on train/eval mode (e.g. BatchNorm/Dropout, if SRResNet uses them; confirm
# in model.py) would run as during training.
model.eval()

# Create a random input tensor (e.g., for testing purposes):
# batch size 1, IN_CHANNELS channels, 128x128 spatial size.
input_tensor = torch.rand(1, IN_CHANNELS, 128, 128)

# Perform inference without recording an autograd graph (saves memory;
# no gradients are needed here).
with torch.no_grad():
    output_tensor = model(input_tensor)

# Print input and output shapes
print("Input shape:", input_tensor.shape)
print("Output shape:", output_tensor.shape)
|