# abhishek-gola - colorization-deploy-v2 (#1), revision 3eb0457
import torch
import torch.nn as nn
import numpy as np
from colorization.colorizers import eccv16
class ExactCaffeMatch(nn.Module):
    """eccv16 colorizer extended with the Caffe deploy model's decode head.

    The eight stages of the pretrained eccv16 network produce pre-softmax
    scores over 313 colour bins. This module then applies a fixed scale, a
    softmax over the bin axis and a 1x1 convolution whose weights are the bin
    centres (``pts_in_hull``), giving a 2-channel output. The intent, per the
    class name, is to reproduce the Caffe ``colorization_deploy_v2`` graph so
    it can be exported to ONNX.

    Args:
        kernel_path: ``.npy`` file with the 313 bin centres (626 values).
            Both the planar ``(2, 313[, 1, 1])`` layout and the row-per-bin
            ``(313, 2)`` layout are accepted. The default is resolved relative
            to the current working directory.
    """

    # Number of colour bins entering the decode convolution.
    _NUM_BINS = 313
    # Fixed multiplier on the bin scores before softmax
    # (the original code calls this "Caffe temperature scaling").
    _LOGIT_SCALE = 2.606

    def __init__(self, kernel_path='opencv_extra/testdata/dnn/colorization_pts_in_hull.npy'):
        super().__init__()
        # Pretrained eccv16 generator, switched to inference mode.
        self.core = eccv16(pretrained=True).eval()
        # Load the color palette kernel. The bin centres become the weights of
        # the final 1x1 decode conv. They are registered as a buffer so they
        # are part of the module's state and follow .to()/.cuda().
        pts_in_hull = np.load(kernel_path)
        self.register_buffer('decode_weight', self._hull_to_kernel(pts_in_hull))

    @staticmethod
    def _hull_to_kernel(pts_in_hull):
        """Return the bin centres as a float32 ``(2, 313, 1, 1)`` conv weight.

        Output channel 0 must hold every bin's first coordinate and channel 1
        every bin's second. A row-per-bin ``(313, 2)`` array is transposed
        first. Any other 626-element array is assumed to be planar already and
        is flattened in C order, as before.

        Raises:
            RuntimeError: if the array does not hold exactly 626 values.
        """
        pts = np.asarray(pts_in_hull, dtype=np.float32)
        if pts.shape == (ExactCaffeMatch._NUM_BINS, 2):
            # Flattening this layout directly would interleave the two
            # coordinates across both output channels.
            pts = pts.T
        return torch.tensor(pts.flatten()).view(2, ExactCaffeMatch._NUM_BINS, 1, 1)

    def forward(self, x):
        """Run eccv16 on ``x`` and decode its bin scores to 2 channels.

        Args:
            x: input batch; divided by 100 before entering the network.
                Presumably the mean-centred L channel (L - 50) fed to the
                Caffe ``data_l`` blob -- TODO confirm.

        Returns:
            The softmax-weighted sum of the bin centres (2 channels).
        """
        x = x / 100.0
        core = self.core
        for stage in (core.model1, core.model2, core.model3, core.model4,
                      core.model5, core.model6, core.model7, core.model8):
            x = stage(x)
        # 1. Apply Caffe temperature scaling to the bin scores.
        x = x * self._LOGIT_SCALE
        # 2. Softmax over the bin (channel) axis.
        x = torch.softmax(x, dim=1)
        # 3. Weighted sum of bin centres: 1x1 conv with the centres as weights.
        return nn.functional.conv2d(x, self.decode_weight)
# Build the wrapper in inference mode explicitly instead of relying on the
# exporter's default training-mode handling.
model = ExactCaffeMatch().eval()
# Single-channel 224x224 dummy batch used only to trace the graph.
dummy_input = torch.randn(1, 1, 224, 224)
# Export with the default (strict ONNX) operator_export_type. The wrapper's
# own ops (div, mul, softmax, conv) have standard opset-11 mappings, and
# presumably eccv16's stages do too -- confirm. If any op ever lacked one,
# ONNX_FALLTHROUGH would silently emit a non-standard node that ONNX
# consumers cannot load. Failing at export time is preferable.
torch.onnx.export(
    model,
    dummy_input,
    "colorization/colorization_deploy_v2_2026april.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['data_l'],
    output_names=['class8_ab'],
)