import torch
import torch.nn as nn


class SimpleUpscaleModel(nn.Module):
    """Upscale image-like tensors with bilinear interpolation.

    The module is a thin wrapper around ``nn.Upsample`` configured with
    ``mode='bilinear'`` and ``align_corners=True``.  It contains no
    learnable parameters or buffers, so its ``state_dict()`` is empty.
    """

    def __init__(self, scale_factor: float = 2) -> None:
        """Build the upsampling layer.

        Args:
            scale_factor: Multiplier applied to the spatial dimensions of
                the input. The value is forwarded unchanged to
                ``nn.Upsample``.
        """
        super().__init__()

        # align_corners=True maps the corner pixels of the input exactly
        # onto the corner pixels of the output.
        self.upsample = nn.Upsample(
            scale_factor=scale_factor,
            mode='bilinear',
            align_corners=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` upscaled by the configured scale factor.

        Args:
            x: Input tensor of shape (batch_size, channels, height, width);
                bilinear mode requires this 4-D layout.

        Returns:
            The upscaled tensor, with height and width multiplied by the
            scale factor.
        """
        return self.upsample(x)


def main() -> None:
    """Create the model and write its state dict to ``model_weights.pth``."""
    # Create the model
    scale_factor = 2
    model = SimpleUpscaleModel(scale_factor=scale_factor)

    # Save the model.
    # NOTE: the model has no parameters or buffers, so the saved state dict
    # is empty; the file only serves as a placeholder checkpoint.
    model_path = "model_weights.pth"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")


if __name__ == "__main__":
    main()