# File size: 1,395 Bytes | 3eb0457
import torch
import torch.nn as nn
import numpy as np
from colorization.colorizers import eccv16
class ExactCaffeMatch(nn.Module):
    """Wrap the pretrained ``eccv16`` network and decode its output to 2 channels.

    The forward pass divides the input by 100, runs it through ``model1`` ..
    ``model8`` of the wrapped network, multiplies the result by 2.606, applies
    a softmax over the channel dimension and finally a 1x1 convolution whose
    kernel is the colour palette loaded from ``pts_in_hull_path``.

    Args:
        pts_in_hull_path: Path of the ``.npy`` palette file. Defaults to the
            location inside ``opencv_extra`` that this script has always used.
    """

    NUM_BINS = 313  # channels fed into the decode convolution
    NUM_AB = 2  # channels produced by the decode convolution
    SOFTMAX_SCALE = 2.606  # "Caffe temperature scaling" applied before softmax
    NUM_STAGES = 8  # the wrapped network exposes model1 .. model8
    DEFAULT_PTS_IN_HULL = 'opencv_extra/testdata/dnn/colorization_pts_in_hull.npy'

    def __init__(self, pts_in_hull_path=DEFAULT_PTS_IN_HULL):
        super().__init__()
        self.core = eccv16(pretrained=True).eval()
        # Load the color palette kernel.
        pts_in_hull = np.load(pts_in_hull_path)
        # A buffer (not a parameter): it is a fixed constant of the model.
        self.register_buffer(
            'decode_weight', self._build_decode_weight(pts_in_hull)
        )

    @staticmethod
    def _build_decode_weight(pts_in_hull):
        """Turn the palette array into a ``(2, 313, 1, 1)`` float32 conv kernel.

        A ``(313, 2)`` palette (one pair of values per bin) is transposed
        first, so that output channel 0 collects column 0 of every bin and
        output channel 1 collects column 1. Flattening such an array and
        viewing it as ``(2, 313)`` without the transpose would interleave the
        two columns across both output channels.

        Any other layout is assumed to be channel-major already and is only
        reshaped -- NOTE(review): confirm if your palette file is not
        ``(313, 2)``.

        Raises:
            RuntimeError: If the array does not hold exactly 2 * 313 values.
        """
        palette = np.asarray(pts_in_hull)
        if palette.shape == (ExactCaffeMatch.NUM_BINS, ExactCaffeMatch.NUM_AB):
            palette = palette.T
        kernel = torch.tensor(np.ascontiguousarray(palette)).float()
        return kernel.reshape(
            ExactCaffeMatch.NUM_AB, ExactCaffeMatch.NUM_BINS, 1, 1
        )

    def forward(self, x):
        """Return the 2-channel decode of ``x``.

        Args:
            x: Input tensor; it is divided by 100 before entering ``model1``.

        Returns:
            The output of the 1x1 palette convolution (2 channels on dim 1).
        """
        x = x / 100.0
        for stage_idx in range(1, self.NUM_STAGES + 1):
            x = getattr(self.core, f'model{stage_idx}')(x)
        # 1. Apply Caffe temperature scaling.
        x = x * self.SOFTMAX_SCALE
        # 2. Softmax over the channel dimension.
        x = torch.softmax(x, dim=1)
        # 3. Weighted sum of the palette entries, done as a 1x1 convolution.
        x = torch.nn.functional.conv2d(x, self.decode_weight)
        return x
# Build the wrapper, then trace it once with a random 1x1x224x224 tensor
# and write the traced graph to an ONNX file.
model = ExactCaffeMatch()
dummy_input = torch.randn(1, 1, 224, 224)

# Export settings, gathered in one place; the tensor names are the ones the
# exported graph exposes for its single input and single output.
_export_options = {
    'export_params': True,
    'opset_version': 11,
    'do_constant_folding': True,
    'input_names': ['data_l'],
    'output_names': ['class8_ab'],
    'operator_export_type': torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
}

torch.onnx.export(
    model,
    dummy_input,
    "colorization/colorization_deploy_v2_2026april.onnx",
    **_export_options,
)