from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from colorization.colorizers import eccv16
|
|
class ExactCaffeMatch(nn.Module):
    """Rebuild the Caffe colorization head on top of the PyTorch ``eccv16`` backbone.

    Only the ``model1`` .. ``model8`` stages of ``eccv16`` are run. The head is
    then recreated explicitly: a fixed 2.606 scale, a softmax over dim 1, and a
    1x1 convolution whose weights are the hull points loaded from
    ``pts_in_hull_path`` (presumably the quantized ab bin centres -- confirm).

    Args:
        pts_in_hull_path: Path to the ``.npy`` file holding the hull points,
            shaped ``(K, 2)`` or ``(2, K)``. A relative path is resolved against
            the current working directory. Defaults to the opencv_extra copy.
    """

    def __init__(
        self,
        pts_in_hull_path='opencv_extra/testdata/dnn/colorization_pts_in_hull.npy',
    ):
        super().__init__()
        self.core = eccv16(pretrained=True).eval()

        pts_in_hull = np.load(pts_in_hull_path)
        # A buffer (rather than a plain attribute) follows .to()/.cuda() and is
        # saved in state_dict alongside the backbone weights.
        self.register_buffer('decode_weight', self._hull_to_conv_weight(pts_in_hull))

    @staticmethod
    def _hull_to_conv_weight(pts_in_hull):
        """Convert hull points into a ``(2, K, 1, 1)`` float32 conv weight.

        Args:
            pts_in_hull: Array-like of shape ``(K, 2)`` (one row per point) or
                ``(2, K)`` (one row per coordinate).

        Returns:
            A float32 tensor ``w`` of shape ``(2, K, 1, 1)`` with
            ``w[c, k, 0, 0] == point k's coordinate c``.

        Raises:
            ValueError: If the input is not 2-D with one axis of length 2.
        """
        pts = np.asarray(pts_in_hull, dtype=np.float32)
        if pts.ndim == 2 and pts.shape[1] == 2:
            # (K, 2) -> (2, K). Flattening a (K, 2) array and viewing it as
            # (2, K) would interleave the two coordinates across channels.
            pts = pts.T
        elif not (pts.ndim == 2 and pts.shape[0] == 2):
            raise ValueError(
                f"expected hull points of shape (K, 2) or (2, K), got {pts.shape}"
            )
        # torch.tensor copies, so the buffer never aliases the caller's array.
        return torch.tensor(pts).reshape(2, -1, 1, 1)

    def forward(self, x):
        """Map an L-channel tensor to 2-channel predictions via the hull points.

        Args:
            x: Input tensor, assumed ``(N, 1, H, W)`` holding the L channel as
                fed to the Caffe ``data_l`` blob (NOTE(review): confirm the
                expected centring/range with the caller).

        Returns:
            Tensor of shape ``(N, 2, h, w)``, where ``h, w`` are the spatial
            dims produced by ``model8``. Channel ``c`` is the softmax-weighted
            sum of every hull point's coordinate ``c``.
        """
        # NOTE(review): assumes the eccv16 stages expect the Caffe input divided
        # by 100 -- verify against eccv16's own input normalisation.
        x = x / 100.0
        core = self.core
        for stage in (core.model1, core.model2, core.model3, core.model4,
                      core.model5, core.model6, core.model7, core.model8):
            x = stage(x)

        # Presumably mirrors the Caffe graph's fixed scale on the class logits.
        # NOTE(review): confirm model8 does not already fold this factor in.
        x = x * 2.606
        x = torch.softmax(x, dim=1)

        # A 1x1 conv with the hull points as weights is the expectation of the
        # hull points under the per-pixel softmax distribution.
        return nn.functional.conv2d(x, self.decode_weight)
|
|
# Build the wrapper and trace it with a random single-channel 224x224 input.
# The tracer records the ops, so presumably only the input shape matters here.
model = ExactCaffeMatch()
dummy_input = torch.randn(1, 1, 224, 224)

# torch.onnx.export does not create missing parent directories.
onnx_path = Path("colorization/colorization_deploy_v2_2026april.onnx")
onnx_path.parent.mkdir(parents=True, exist_ok=True)

# The default operator_export_type (plain ONNX) is used on purpose: with
# ONNX_FALLTHROUGH, an unsupported op would be emitted as a non-standard node
# and the file would fail only later, at load time in OpenCV, rather than here.
torch.onnx.export(
    model,
    dummy_input,
    str(onnx_path),
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['data_l'],
    output_names=['class8_ab'],
)