| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | import coremltools as ct |
| |
|
| | from torch.utils.data import Dataset, DataLoader |
| | from torchvision import transforms |
| | import os |
| | from PIL import Image |
| | import torchvision.transforms.functional as TF |
| |
|
# Run on Apple's Metal (MPS) backend when this PyTorch build and machine
# support it; otherwise fall back to the CPU so the script still runs.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
class UPSC(nn.Module):
    """Small sub-pixel convolution network for image upscaling.

    Three convolutions (5x5 -> 3x3 -> 3x3, with ReLU after the first two)
    produce ``channels * scale_factor ** 2`` feature maps, which
    ``nn.PixelShuffle`` rearranges into an output that is ``scale_factor``
    times larger along height and width.

    The layers live in ``self.model`` in a fixed order, so the parameter
    keys are ``model.0.*``, ``model.2.*`` and ``model.4.*`` regardless of
    the arguments. With the default arguments the layer shapes are
    ``3 -> 64 -> 32 -> 27`` channels followed by ``PixelShuffle(3)``.

    Args:
        scale_factor: Integer upscaling factor applied to height and width.
        channels: Number of channels in both the input and the output image.

    Raises:
        ValueError: If ``scale_factor`` or ``channels`` is smaller than 1.
    """

    def __init__(self, scale_factor: int = 3, channels: int = 3) -> None:
        super().__init__()
        if scale_factor < 1:
            raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")
        if channels < 1:
            raise ValueError(f"channels must be >= 1, got {channels}")

        self.model = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            # PixelShuffle(r) consumes r * r feature maps per output channel.
            nn.Conv2d(
                in_channels=32,
                out_channels=channels * scale_factor * scale_factor,
                kernel_size=3,
                padding=1,
            ),
            nn.PixelShuffle(scale_factor),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the convolution stack and pixel shuffle."""
        return self.model(x)
| |
|
# --- Load the trained weights and switch to inference mode ----------------
model = UPSC().to(device)
# map_location remaps tensors that were saved on a different device onto
# `device`, so the checkpoint loads wherever this script runs.
state_dict = torch.load("upscaling.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# --- Prepare the low-resolution input --------------------------------------
img = Image.open("test.png").convert("RGB")

# Resize to a fixed 256x256 and convert to a float tensor scaled to [0, 1].
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Add the batch dimension and move the tensor to the model's device.
lr_tensor = transform(img).unsqueeze(0).to(device)

# --- Upscale and trace -----------------------------------------------------
with torch.no_grad():
    sr_tensor = model(lr_tensor)
    # NOTE(review): the trace is recorded on `device`; if the Core ML
    # conversion below rejects non-CPU tensors, trace a CPU copy instead.
    traced_model = torch.jit.trace(model, lr_tensor)

# Clamp to the valid [0, 1] range and hand PIL a CPU tensor.
sr_image = TF.to_pil_image(sr_tensor.squeeze(0).clamp(0, 1).cpu())
sr_image.save("upscaled_output_5.jpg")

# --- Export to Core ML -------------------------------------------------------
mlmodel = ct.convert(
    traced_model,
    # Pin the neural-network format: recent coremltools releases default to
    # ML Program, which can only be saved as ".mlpackage", not ".mlmodel".
    convert_to="neuralnetwork",
    inputs=[
        ct.ImageType(
            name="input",
            shape=lr_tensor.shape,
            # Core ML feeds image inputs as 0-255 pixel values, while the
            # network sees ToTensor's [0, 1] range above; rescale to match.
            scale=1 / 255.0,
        )
    ],
    compute_units=ct.ComputeUnit.ALL,
)

mlmodel.save("upscaling.mlmodel")
| |
|