import numpy as np
import torch
import torch.nn as nn

from colorization.colorizers import eccv16
class ExactCaffeMatch(nn.Module):
    """ECCV16 colorizer re-wired to emit Caffe-style ``class8_ab`` output.

    Runs only ``model1`` .. ``model8`` of the backbone; any other attribute of
    the backbone is not used. The tail is then rebuilt by hand:
    1. a constant temperature scale on the stage-8 logits;
    2. a softmax over the channel axis (the 313 colour bins);
    3. a 1x1 convolution whose kernel is the bin-centre table loaded from
       ``kernel_path``.

    Args:
        kernel_path: Path to the ``.npy`` bin-centre table. There are two
            accepted layouts:
            - ``(313, 2)``: one ``(a, b)`` row per bin. It is transposed
              before use.
            - Any array already holding the 626 values in ``(2, 313)``
              order, e.g. ``(2, 313, 1, 1)``. It is reshaped as-is.
        temperature_scale: Multiplier applied to the stage-8 logits before
            the softmax. The default is the original script's constant.
        core: Module exposing ``model1`` .. ``model8``. When ``None``, the
            pretrained ``eccv16`` model is built and put in eval mode.

    Raises:
        RuntimeError: If the loaded table does not hold exactly 2 * 313 values.
    """

    # Backbone stages, executed in this order by ``forward``.
    _STAGES = (
        "model1", "model2", "model3", "model4",
        "model5", "model6", "model7", "model8",
    )

    def __init__(
        self,
        kernel_path='opencv_extra/testdata/dnn/colorization_pts_in_hull.npy',
        temperature_scale=2.606,
        core=None,
    ):
        super().__init__()
        if core is None:
            core = eccv16(pretrained=True).eval()
        self.core = core
        self.temperature_scale = temperature_scale

        # Load the colour-bin centres used as the decoding kernel.
        hull = torch.as_tensor(np.load(kernel_path), dtype=torch.float32)
        # A (313, 2) table stores (a, b) pairs row by row. Flattening it directly
        # would interleave a and b across the two output channels, so transpose
        # first: channel 0 then holds every a value and channel 1 every b value.
        if hull.shape == (313, 2):
            hull = hull.t()
        # reshape raises RuntimeError when the table is not 626 values.
        self.register_buffer(
            'decode_weight', hull.reshape(2, 313, 1, 1).contiguous()
        )

    def forward(self, x):
        """Map an input batch to 2-channel ab predictions.

        Args:
            x: Tensor passed, after division by 100, to ``core.model1``.
                NOTE(review): presumably mean-centred lightness (L - 50), as
                the Caffe ``data_l`` blob expects. Confirm with the caller.

        Returns:
            A 2-channel tensor holding the softmax-weighted sum of the bin
            centres.
        """
        x = x / 100.0
        for stage_name in self._STAGES:
            x = getattr(self.core, stage_name)(x)
        # 1. Caffe temperature scaling.
        # 2. Softmax over the colour-bin (channel) axis.
        x = torch.softmax(x * self.temperature_scale, dim=1)
        # 3. 1x1 conv with the bin centres: the expected (a, b) per pixel.
        return torch.nn.functional.conv2d(x, self.decode_weight)
# Build the export wrapper and a single 1x1x224x224 example input. The ONNX
# graph is traced from this example; its random values are not stored.
model = ExactCaffeMatch()
dummy_input = torch.randn(1, 1, 224, 224)

# Serialise to ONNX. The I/O names are presumably chosen to match the Caffe
# deploy blobs ('data_l' in, 'class8_ab' out).
# NOTE(review): ONNX_FALLTHROUGH exports ops without an ONNX mapping as-is
# instead of failing the export. Confirm it is still wanted for this model.
torch.onnx.export(
    model,
    dummy_input,
    "colorization/colorization_deploy_v2_2026april.onnx",
    input_names=['data_l'],
    output_names=['class8_ab'],
    opset_version=11,
    export_params=True,
    do_constant_folding=True,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
)