Converted to ONNX using the following code:

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as transforms

# Define the model architecture (based on the code you provided)

class WatermarkRemover(nn.Module):
    """U-Net-like encoder/decoder mapping a 3-channel image to a 3-channel image.

    Layout:
      * four encoder stages (``enc1``..``enc4``); every stage after the first
        receives the previous stage's output downsampled by a 2x2 max-pool;
      * a ``bottleneck`` stage applied to the max-pooled ``enc4`` output;
      * four decoder stages (``dec4``..``dec1``), each fed the previous stage's
        output bilinearly resized to the matching encoder output's H x W and
        concatenated with it along the channel axis (skip connection);
      * a 1x1 ``final_layer`` projecting 64 channels to 3.

    The attribute names define the ``state_dict`` keys, so they must not be
    renamed or checkpoints saved from this architecture will stop loading.
    """

    def __init__(self):
        super().__init__()
        self.enc1 = self.conv_block(3, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)

        self.bottleneck = self.conv_block(512, 1024)

        # Decoder input channels = upsampled channels + skip-connection channels.
        self.dec4 = self.conv_block(1024 + 512, 512)
        self.dec3 = self.conv_block(512 + 256, 256)
        self.dec2 = self.conv_block(256 + 128, 128)  # = 384 input channels
        self.dec1 = self.conv_block(128 + 64, 64)    # = 192 input channels

        self.final_layer = nn.Conv2d(64, 3, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Build two (3x3 conv -> ReLU) layers; ``padding=1`` preserves H x W.

        Args:
            in_channels: Number of channels entering the block.
            out_channels: Number of channels produced by both convolutions.

        Returns:
            An ``nn.Sequential`` with indices 0/2 being the convolutions.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _up_concat(x, skip):
        """Resize ``x`` to ``skip``'s spatial size and concatenate on channels.

        Resizing to the skip tensor's exact size (instead of a fixed 2x scale
        factor) keeps odd input sizes working: max-pooling floors odd sizes,
        so a plain 2x upsample could come out one pixel short.
        """
        up = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        return torch.cat((up, skip), dim=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Batched input of shape (N, 3, H, W). H and W must each be at
                least 16 so that four 2x2 max-pools leave a non-empty
                bottleneck feature map.

        Returns:
            Tensor of shape (N, 3, H, W); no output activation is applied.
        """
        # F.max_pool2d is the stateless equivalent of building nn.MaxPool2d(2)
        # on every call (stride defaults to the kernel size in both).
        e1 = self.enc1(x)
        e2 = self.enc2(F.max_pool2d(e1, 2))
        e3 = self.enc3(F.max_pool2d(e2, 2))
        e4 = self.enc4(F.max_pool2d(e3, 2))

        b = self.bottleneck(F.max_pool2d(e4, 2))

        d4 = self.dec4(self._up_concat(b, e4))
        d3 = self.dec3(self._up_concat(d4, e3))
        d2 = self.dec2(self._up_concat(d3, e2))
        d1 = self.dec1(self._up_concat(d2, e1))

        return self.final_layer(d1)

# Load the PyTorch model

model = WatermarkRemover()
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs. Older PyTorch releases default to full unpickling,
# which can run arbitrary code embedded in a downloaded checkpoint file.
model.load_state_dict(
    torch.load('/content/model.pth', map_location=torch.device('cpu'), weights_only=True)
)
# Put the module in inference mode before tracing it for export.
model.eval()

# Load and preprocess the sample input image

input_image_path = '/content/images.jpeg'
# Open in a context manager so the file handle is closed promptly;
# convert('RGB') returns an independent, fully loaded image, so it stays
# usable after the source file is closed.
with Image.open(input_image_path) as img:
    input_image = img.convert('RGB')
# ToTensor turns the PIL image into a float CHW tensor (scaled to [0, 1] for
# 8-bit images).
preprocess = transforms.Compose([
    transforms.ToTensor(),
])
input_tensor = preprocess(input_image)
# Add a batch dimension -> (1, 3, H, W). This tensor is only used to trace the
# model, but H and W must be >= 16 for the four max-pool stages.
input_batch = input_tensor.unsqueeze(0)

# Define the output path for the ONNX model

output_onnx_path = "/content/model.onnx"  # destination file for the exported ONNX graph

# Export the model to ONNX format

# Trace the model with the sample batch and write the ONNX graph to disk.
# Height and width are declared dynamic alongside the batch size. The network
# is fully convolutional and resizes to its skip tensors' runtime sizes, so it
# is not tied to the sample image's resolution. Without these axes, the
# exported graph's input would be fixed to exactly the sample's H x W.
torch.onnx.export(
    model,                     # model being run
    input_batch,               # sample input used for tracing
    output_onnx_path,          # where to save the model (can be a file or file-like object)
    export_params=True,        # store the trained parameter weights inside the model file
    opset_version=11,          # the ONNX opset version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=['input'],     # the name of the input node(s)
    output_names=['output'],   # the name of the output node(s)
    dynamic_axes={             # axes whose size may vary at inference time
        'input': {0: 'batch_size', 2: 'height', 3: 'width'},
        'output': {0: 'batch_size', 2: 'height', 3: 'width'},
    },
)

print(f"Model successfully converted to ONNX format and saved at: {output_onnx_path}")

Ready to merge
This branch is ready to get merged automatically.

Sign up or log in to comment