File size: 1,395 Bytes
3eb0457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torch.nn as nn
import numpy as np
from colorization.colorizers import eccv16

class ExactCaffeMatch(nn.Module):
    """Reproduce the Caffe colorization decode head on top of the eccv16 trunk.

    The trunk (stages ``model1`` .. ``model8`` of the pretrained ``eccv16``
    network) produces per-pixel scores over ``NUM_BINS`` quantized ab bins.
    This wrapper then applies the Caffe temperature scaling, a softmax over the
    bins, and a fixed 1x1 convolution whose weights are the bin centres loaded
    from ``pts_in_hull_path``, yielding a 2-channel (a, b) map.

    Args:
        pts_in_hull_path: ``.npy`` file holding the ``NUM_BINS`` ab bin
            centres, either as ``(NUM_BINS, 2)`` (one ``(a, b)`` row per bin)
            or already in kernel layout ``(2, NUM_BINS)`` /
            ``(2, NUM_BINS, 1, 1)``.
        temperature: multiplier applied to the bin scores before the softmax.

    Raises:
        ValueError: if the bin-centre file does not hold ``2 * NUM_BINS``
            values.
    """

    # Default cluster-centre file inside an opencv_extra checkout.
    DEFAULT_PTS_IN_HULL = 'opencv_extra/testdata/dnn/colorization_pts_in_hull.npy'
    # Caffe temperature scaling applied to the bin scores before the softmax.
    CAFFE_TEMPERATURE = 2.606
    # Number of quantized ab bins (= input channels of the decode kernel).
    NUM_BINS = 313
    # Trunk stages of the eccv16 network, applied in this order.
    _STAGES = ('model1', 'model2', 'model3', 'model4',
               'model5', 'model6', 'model7', 'model8')

    def __init__(self, pts_in_hull_path=DEFAULT_PTS_IN_HULL,
                 temperature=CAFFE_TEMPERATURE):
        super().__init__()
        self.core = eccv16(pretrained=True).eval()
        self.temperature = float(temperature)
        # Registered as a buffer (not a parameter) so it is saved/exported
        # with the module but never treated as trainable.
        self.register_buffer('decode_weight',
                             self._load_decode_kernel(pts_in_hull_path))

    @classmethod
    def _load_decode_kernel(cls, path):
        """Load the ab bin centres from *path* as a float32 (2, NUM_BINS, 1, 1) kernel.

        Raises:
            ValueError: if the file does not contain exactly 2 * NUM_BINS values.
        """
        pts = np.load(path).astype(np.float32)
        if pts.shape == (cls.NUM_BINS, 2):
            # One (a, b) row per bin: transpose so output channel 0 holds all
            # a values and channel 1 all b values. A plain flatten/reshape of
            # this layout would interleave a and b and scramble the kernel.
            pts = pts.T
        if pts.size != 2 * cls.NUM_BINS:
            raise ValueError(
                f'expected {2 * cls.NUM_BINS} bin-centre values in {path!r}, '
                f'got array of shape {pts.shape}')
        return torch.from_numpy(np.ascontiguousarray(pts)).reshape(
            2, cls.NUM_BINS, 1, 1)

    def forward(self, x):
        """Run the eccv16 trunk and decode bin probabilities to an ab map.

        Args:
            x: input batch. NOTE(review): the dummy export input is
                (1, 1, 224, 224); values are presumably mean-centred L
                (e.g. L - 50, as fed to Caffe's ``data_l``) and the /100 below
                is assumed to match the trunk's expected scaling — confirm.

        Returns:
            Tensor with 2 channels (a, b) at the trunk's output resolution.
        """
        x = x / 100.0
        for stage_name in self._STAGES:
            x = getattr(self.core, stage_name)(x)
        # Caffe temperature scaling, then softmax over the bin axis.
        x = torch.softmax(x * self.temperature, dim=1)
        # 1x1 conv == per-pixel probability-weighted sum of the bin centres.
        return torch.nn.functional.conv2d(x, self.decode_weight)

# Build the export wrapper plus a random tracing input of shape
# (batch=1, channels=1, 224, 224).
# NOTE(review): the wrapper itself is not switched to eval() here; the wrapped
# core is already .eval() and torch.onnx.export exports in eval mode by default.
model = ExactCaffeMatch()
dummy_input = torch.randn(1, 1, 224, 224)

# Trace and serialise to ONNX. The graph's input/output are named after the
# blobs of the original Caffe deploy model; weights are embedded and
# constant subgraphs folded.
torch.onnx.export(
    model,
    dummy_input,
    "colorization/colorization_deploy_v2_2026april.onnx",
    input_names=['data_l'],
    output_names=['class8_ab'],
    opset_version=11,
    export_params=True,
    do_constant_folding=True,
    # Emit any op without an ONNX symbolic as-is instead of failing.
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
)