ONNX
text-detection
craft
inference4j
vccarvalho11 commited on
Commit
99781a5
·
verified ·
1 Parent(s): 97a2be5

Upload CRAFT MLT 25K ONNX model

Browse files
Files changed (2) hide show
  1. README.md +15 -0
  2. craft_exporter.py +214 -0
README.md CHANGED
@@ -62,6 +62,21 @@ try (Craft craft = Craft.fromPretrained("models/craft-mlt-25k")) {
62
  4. For each component: compute mean region score, filter by `text_threshold` (default 0.7)
63
  5. Extract axis-aligned bounding box, scale back to original image coordinates
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  ## Original Paper
66
 
67
  > Baek, Y., Lee, B., Han, D., Yun, S., & Lee, H. (2019).
 
62
  4. For each component: compute mean region score, filter by `text_threshold` (default 0.7)
63
  5. Extract axis-aligned bounding box, scale back to original image coordinates
64
 
65
+ ## Conversion
66
+
67
+ This model was converted from PyTorch to ONNX using [`craft_exporter.py`](craft_exporter.py) included in this repo. To reproduce:
68
+
69
+ ```bash
70
+ # Download original PyTorch weights (~79 MB)
71
+ gdown 1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ -O craft_mlt_25k.pth
72
+
73
+ # Install dependencies
74
+ pip install torch torchvision onnx onnxruntime
75
+
76
+ # Export to ONNX
77
+ python craft_exporter.py
78
+ ```
79
+
80
  ## Original Paper
81
 
82
  > Baek, Y., Lee, B., Han, D., Yun, S., & Lee, H. (2019).
craft_exporter.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CRAFT (Character Region Awareness for Text Detection) — ONNX Export Script
3
+
4
+ Exports the CRAFT MLT 25K model from PyTorch to ONNX format.
5
+
6
+ Usage:
7
+ 1. Download weights from https://drive.google.com/uc?id=1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ
8
+ (or use gdown: gdown 1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ -O craft_mlt_25k.pth)
9
+ 2. pip install torch torchvision onnx onnxruntime
10
+ 3. python craft_exporter.py
11
+
12
+ Original weights: clovaai/CRAFT-pytorch (https://github.com/clovaai/CRAFT-pytorch)
13
+ """
14
+
15
+ import os
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from collections import OrderedDict
20
+ from torchvision import models
21
+
22
+
23
+ # ============================================================
24
+ # Model definitions matching clovaai/CRAFT-pytorch exactly
25
+ # ============================================================
26
+
27
+ def init_weights(modules):
28
+ for m in modules:
29
+ if isinstance(m, nn.Conv2d):
30
+ nn.init.xavier_uniform_(m.weight.data)
31
+ if m.bias is not None:
32
+ m.bias.data.zero_()
33
+ elif isinstance(m, nn.BatchNorm2d):
34
+ m.weight.data.fill_(1)
35
+ m.bias.data.zero_()
36
+
37
+
38
+ class VGG16BN(nn.Module):
39
+ def __init__(self, pretrained=False):
40
+ super().__init__()
41
+ vgg_pretrained_features = models.vgg16_bn(pretrained=False).features
42
+
43
+ self.slice1 = nn.Sequential()
44
+ self.slice2 = nn.Sequential()
45
+ self.slice3 = nn.Sequential()
46
+ self.slice4 = nn.Sequential()
47
+ self.slice5 = nn.Sequential()
48
+
49
+ # Use add_module with original indices to match state_dict keys
50
+ for x in range(12): # conv2_2
51
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
52
+ for x in range(12, 19): # conv3_3
53
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
54
+ for x in range(19, 29): # conv4_3
55
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
56
+ for x in range(29, 39): # conv5_3
57
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
58
+
59
+ # fc6, fc7 without atrous conv
60
+ self.slice5 = nn.Sequential(
61
+ nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
62
+ nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
63
+ nn.Conv2d(1024, 1024, kernel_size=1),
64
+ )
65
+
66
+ init_weights(self.slice5.modules())
67
+
68
+ def forward(self, x):
69
+ h = self.slice1(x)
70
+ h_relu2_2 = h
71
+ h = self.slice2(h)
72
+ h_relu3_2 = h
73
+ h = self.slice3(h)
74
+ h_relu4_3 = h
75
+ h = self.slice4(h)
76
+ h_relu5_3 = h
77
+ h = self.slice5(h)
78
+ h_fc7 = h
79
+ # Return order: fc7, relu5_3, relu4_3, relu3_2, relu2_2
80
+ return h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2
81
+
82
+
83
+ class DoubleConv(nn.Module):
84
+ def __init__(self, in_ch, mid_ch, out_ch):
85
+ super().__init__()
86
+ self.conv = nn.Sequential(
87
+ nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
88
+ nn.BatchNorm2d(mid_ch),
89
+ nn.ReLU(inplace=True),
90
+ nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
91
+ nn.BatchNorm2d(out_ch),
92
+ nn.ReLU(inplace=True),
93
+ )
94
+
95
+ def forward(self, x):
96
+ return self.conv(x)
97
+
98
+
99
+ class CRAFT(nn.Module):
100
+ def __init__(self):
101
+ super().__init__()
102
+ self.basenet = VGG16BN()
103
+
104
+ # U network
105
+ self.upconv1 = DoubleConv(1024, 512, 256)
106
+ self.upconv2 = DoubleConv(512, 256, 128)
107
+ self.upconv3 = DoubleConv(256, 128, 64)
108
+ self.upconv4 = DoubleConv(128, 64, 32)
109
+
110
+ num_class = 2
111
+ self.conv_cls = nn.Sequential(
112
+ nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
113
+ nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
114
+ nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
115
+ nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
116
+ nn.Conv2d(16, num_class, kernel_size=1),
117
+ )
118
+
119
+ init_weights(self.upconv1.modules())
120
+ init_weights(self.upconv2.modules())
121
+ init_weights(self.upconv3.modules())
122
+ init_weights(self.upconv4.modules())
123
+ init_weights(self.conv_cls.modules())
124
+
125
+ def forward(self, x):
126
+ # Base network
127
+ sources = self.basenet(x)
128
+ # sources = (fc7, relu5_3, relu4_3, relu3_2, relu2_2)
129
+
130
+ # U network
131
+ y = torch.cat([sources[0], sources[1]], dim=1)
132
+ y = self.upconv1(y)
133
+
134
+ y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
135
+ y = torch.cat([y, sources[2]], dim=1)
136
+ y = self.upconv2(y)
137
+
138
+ y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
139
+ y = torch.cat([y, sources[3]], dim=1)
140
+ y = self.upconv3(y)
141
+
142
+ y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
143
+ y = torch.cat([y, sources[4]], dim=1)
144
+ feature = self.upconv4(y)
145
+
146
+ y = self.conv_cls(feature)
147
+
148
+ return y.permute(0, 2, 3, 1), feature
149
+
150
+
151
+ # ============================================================
152
+ # Export and validate
153
+ # ============================================================
154
+
155
+ WEIGHTS_PATH = "craft_mlt_25k.pth"
156
+ OUTPUT_PATH = "model.onnx"
157
+
158
+
159
+ def load_model():
160
+ model = CRAFT()
161
+ state_dict = torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True)
162
+
163
+ # Handle DataParallel 'module.' prefix
164
+ new_state_dict = OrderedDict()
165
+ for k, v in state_dict.items():
166
+ name = k.replace("module.", "")
167
+ new_state_dict[name] = v
168
+
169
+ model.load_state_dict(new_state_dict)
170
+ model.eval()
171
+ return model
172
+
173
+
174
+ def export_onnx(model):
175
+ dummy_input = torch.randn(1, 3, 640, 640)
176
+
177
+ torch.onnx.export(
178
+ model,
179
+ dummy_input,
180
+ OUTPUT_PATH,
181
+ opset_version=17,
182
+ input_names=["input"],
183
+ output_names=["score_map", "feature_map"],
184
+ dynamic_axes={
185
+ "input": {0: "batch", 2: "height", 3: "width"},
186
+ "score_map": {0: "batch", 1: "height", 2: "width"},
187
+ "feature_map": {0: "batch", 2: "height", 3: "width"},
188
+ },
189
+ )
190
+ print(f"Exported to {OUTPUT_PATH}")
191
+
192
+
193
+ def validate():
194
+ import onnxruntime as ort
195
+ import numpy as np
196
+
197
+ session = ort.InferenceSession(OUTPUT_PATH)
198
+ dummy = np.random.randn(1, 3, 640, 640).astype(np.float32)
199
+ results = session.run(None, {"input": dummy})
200
+ print(f"Validation OK:")
201
+ print(f" score_map shape: {results[0].shape}") # (1, 320, 320, 2)
202
+ print(f" feature_map shape: {results[1].shape}") # (1, 32, 320, 320)
203
+
204
+
205
+ if __name__ == "__main__":
206
+ if not os.path.exists(WEIGHTS_PATH):
207
+ print(f"Download weights first:")
208
+ print(f" gdown 1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ -O {WEIGHTS_PATH}")
209
+ print(f" (from https://github.com/clovaai/CRAFT-pytorch)")
210
+ exit(1)
211
+
212
+ model = load_model()
213
+ export_onnx(model)
214
+ validate()