File size: 567 Bytes
4b697f3
 
60da226
4b697f3
 
c123368
4b697f3
 
 
c123368
4b697f3
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# example.py
import torch
from model import SRResNet  # Import your model class

# Build the network and restore its trained weights.
# The input channel count is named once so the model constructor and the
# test tensor below cannot drift apart.
IN_CHANNELS = 12
model = SRResNet(in_channels=IN_CHANNELS, out_channels=72, upscale=1)

# map_location="cpu" lets the checkpoint load on machines without a GPU even
# if it was saved from CUDA tensors; the model above is constructed on CPU,
# so the weights end up on the same device either way.
# NOTE(review): on recent PyTorch versions consider passing weights_only=True
# to torch.load for safer unpickling -- confirm the minimum supported version.
state_dict = torch.load("best_netG.pth", map_location="cpu")
model.load_state_dict(state_dict)

# Switch to evaluation mode so that layers such as dropout or batch norm
# (if the model contains any) use their inference behaviour.
model.eval()

# Create a random input tensor for a smoke test:
# batch size 1, IN_CHANNELS (12) channels, 128x128 spatial size.
input_tensor = torch.rand(1, IN_CHANNELS, 128, 128)

# Perform inference; no_grad() skips autograd bookkeeping, which is not
# needed when only the forward pass is evaluated.
with torch.no_grad():
    output_tensor = model(input_tensor)

# Print input and output shapes
print("Input shape:", input_tensor.shape)
print("Output shape:", output_tensor.shape)