from typing import Optional

import torch
import torch.nn as nn


class SimpleUpscaleModel(nn.Module):
    """Parameter-free upscaler that wraps ``torch.nn.Upsample``.

    The module owns no learnable parameters (``nn.Upsample`` has none), so
    ``state_dict()`` is empty. The configuration lives only in the constructor
    arguments; anything that reloads the model must pass the same arguments.
    """

    def __init__(
        self,
        scale_factor: float = 2,
        mode: str = "bilinear",
        align_corners: Optional[bool] = True,
    ):
        """
        Build the upsampling layer.

        Args:
            scale_factor (int or float): Multiplier applied to each spatial
                dimension of the input. Defaults to 2.
            mode (str): Interpolation algorithm forwarded to ``nn.Upsample``
                (e.g. ``'bilinear'``, ``'bicubic'``, ``'nearest'``). Defaults
                to ``'bilinear'``, which expects 4D ``(N, C, H, W)`` input.
            align_corners (bool or None): Forwarded to ``nn.Upsample``.
                Defaults to ``True``, matching the original hard-coded setting.
                Must be ``None`` for non-interpolating modes such as
                ``'nearest'``; otherwise PyTorch raises ``ValueError`` when the
                layer is called.
        """
        super().__init__()
        self.upsample = nn.Upsample(
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Upscale ``x`` with the configured interpolation.

        Args:
            x (torch.Tensor): Input tensor. For the default bilinear mode this
                must be 4D: (batch_size, channels, height, width).

        Returns:
            torch.Tensor: Tensor whose spatial dimensions are multiplied by
            ``scale_factor`` (floored); batch and channel dims are unchanged.
        """
        return self.upsample(x)


if __name__ == "__main__":
    # Build the default 2x upscaler and write its state dict to disk.
    # NOTE: SimpleUpscaleModel only wraps nn.Upsample, which has no learnable
    # parameters, so the saved state dict is empty and does not record the
    # scale factor -- a loader must rebuild the model with the same arguments.
    model_path = "model_weights.pth"
    factor = 2
    upscaler = SimpleUpscaleModel(scale_factor=factor)

    weights = upscaler.state_dict()
    torch.save(weights, model_path)

    print(f"Model saved to {model_path}")