omrope792 commited on
Commit
2a842b0
·
1 Parent(s): 0ec930e

Add: model conversion script

Browse files
submissions/colorization/model_conversion.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from colorization.colorizers import eccv16
5
+
6
+ class ExactCaffeMatch(nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+ self.core = eccv16(pretrained=True).eval()
10
+
11
+ # Load the color palette kernel
12
+ pts_in_hull = np.load('opencv_extra/testdata/dnn/colorization_pts_in_hull.npy')
13
+
14
+ weight_tensor = torch.tensor(pts_in_hull.flatten()).float().view(2, 313, 1, 1)
15
+ self.register_buffer('decode_weight', weight_tensor)
16
+
17
+ def forward(self, x):
18
+ x = x / 100.0
19
+ x = self.core.model1(x)
20
+ x = self.core.model2(x)
21
+ x = self.core.model3(x)
22
+ x = self.core.model4(x)
23
+ x = self.core.model5(x)
24
+ x = self.core.model6(x)
25
+ x = self.core.model7(x)
26
+ x = self.core.model8(x)
27
+
28
+ # 1. Apply Caffe temperature scaling
29
+ x = x * 2.606# 2. Softmax
30
+ x = torch.softmax(x, dim=1)
31
+
32
+ x = torch.nn.functional.conv2d(x, self.decode_weight)
33
+
34
+ return x
35
+
36
+ model = ExactCaffeMatch()
37
+ dummy_input = torch.randn(1, 1, 224, 224)
38
+
39
+ torch.onnx.export(
40
+ model, dummy_input,
41
+ "colorization/colorization_deploy_v2_2026april.onnx",
42
+ export_params=True,
43
+ opset_version=11,
44
+ do_constant_folding=True,
45
+ input_names=['data_l'],
46
+ output_names=['class8_ab'],
47
+ operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH
48
+ )