| """ |
| CRAFT (Character Region Awareness for Text Detection) — ONNX Export Script |
| |
| Exports the CRAFT MLT 25K model from PyTorch to ONNX format. |
| |
| Usage: |
| 1. Download weights from https://drive.google.com/uc?id=1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ |
| (or use gdown: gdown 1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ -O craft_mlt_25k.pth) |
| 2. pip install torch torchvision onnx onnxruntime |
| 3. python craft_exporter.py |
| |
| Original weights: clovaai/CRAFT-pytorch (https://github.com/clovaai/CRAFT-pytorch) |
| """ |
|
|
| import os |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from collections import OrderedDict |
| from torchvision import models |
|
|
|
|
| |
| |
| |
|
|
def init_weights(modules):
    """Initialize conv and batch-norm layers in place.

    Conv2d weights get Xavier/Glorot-uniform init and zero biases;
    BatchNorm2d gets unit weights and zero biases. Any other module
    type is left untouched.

    Args:
        modules: iterable of nn.Module, e.g. ``some_module.modules()``.
    """
    for m in modules:
        if isinstance(m, nn.Conv2d):
            # nn.init functions operate on Parameters directly (under
            # no_grad internally) — no need for the deprecated `.data`.
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
|
|
|
|
class VGG16BN(nn.Module):
    """VGG-16-BN backbone returning five intermediate feature maps.

    The torchvision feature stack is split into four slices at the same
    layer indices as clovaai/CRAFT-pytorch, plus a fifth slice that
    replaces the fc6/fc7 head with dilated convolutions.

    Args:
        pretrained: if True, load ImageNet weights for the VGG slices.
            (Previously this flag was accepted but silently ignored —
            the backbone was always randomly initialized.)
    """

    def __init__(self, pretrained=False):
        super().__init__()
        # Modern torchvision API: `weights=` replaces the deprecated
        # `pretrained=` keyword.
        weights = models.VGG16_BN_Weights.IMAGENET1K_V1 if pretrained else None
        vgg_features = models.vgg16_bn(weights=weights).features

        self.slice1 = nn.Sequential()
        self.slice2 = nn.Sequential()
        self.slice3 = nn.Sequential()
        self.slice4 = nn.Sequential()

        # Slice boundaries match the reference CRAFT implementation.
        for x in range(12):
            self.slice1.add_module(str(x), vgg_features[x])
        for x in range(12, 19):
            self.slice2.add_module(str(x), vgg_features[x])
        for x in range(19, 29):
            self.slice3.add_module(str(x), vgg_features[x])
        for x in range(29, 39):
            self.slice4.add_module(str(x), vgg_features[x])

        # fc6/fc7 replacement: dilated 3x3 conv + 1x1 conv, keeping
        # spatial resolution (stride-1 pool).
        self.slice5 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
            nn.Conv2d(1024, 1024, kernel_size=1),
        )

        # Only slice5 is freshly created here; slices 1-4 keep their
        # torchvision (or ImageNet-pretrained) initialization.
        init_weights(self.slice5.modules())

    def forward(self, x):
        """Return (h_fc7, relu5_3, relu4_3, relu3_2, relu2_2), deepest first."""
        h = self.slice1(x)
        h_relu2_2 = h
        h = self.slice2(h)
        h_relu3_2 = h
        h = self.slice3(h)
        h_relu4_3 = h
        h = self.slice4(h)
        h_relu5_3 = h
        h = self.slice5(h)
        h_fc7 = h
        return h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2
|
|
|
|
class DoubleConv(nn.Module):
    """1x1 bottleneck followed by a 3x3 conv, each with BN + ReLU.

    The input is expected to be the channel-wise concatenation of a
    decoder feature map (``in_ch`` channels) and an encoder skip
    connection (``mid_ch`` channels); the output has ``out_ch`` channels
    at unchanged spatial resolution.
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return out
|
|
|
|
class CRAFT(nn.Module):
    """CRAFT text detector: VGG-16-BN encoder with a U-Net-style decoder.

    ``forward`` returns a pair:
      * score map in channel-last (NHWC) layout with 2 channels
        (region score and affinity score, per the reference repo),
      * the 32-channel decoder feature map (NCHW) it was computed from.
    """

    def __init__(self):
        super().__init__()
        self.basenet = VGG16BN()

        # Decoder stages: each fuses an encoder skip and halves channels.
        self.upconv1 = DoubleConv(1024, 512, 256)
        self.upconv2 = DoubleConv(512, 256, 128)
        self.upconv3 = DoubleConv(256, 128, 64)
        self.upconv4 = DoubleConv(128, 64, 32)

        # Two-channel classification head.
        num_class = 2
        self.conv_cls = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, num_class, kernel_size=1),
        )

        # Fresh init for everything except the (separately initialized)
        # backbone.
        for module in (self.upconv1, self.upconv2, self.upconv3,
                       self.upconv4, self.conv_cls):
            init_weights(module.modules())

    def forward(self, x):
        # Encoder outputs, deepest first.
        fc7, relu5, relu4, relu3, relu2 = self.basenet(x)

        # Deepest fusion needs no resize: fc7 and relu5_3 share a size.
        feat = self.upconv1(torch.cat([fc7, relu5], dim=1))

        # Walk back up the pyramid: resize to each skip, concat, refine.
        for skip, upconv in ((relu4, self.upconv2),
                             (relu3, self.upconv3),
                             (relu2, self.upconv4)):
            feat = F.interpolate(feat, size=skip.size()[2:],
                                 mode='bilinear', align_corners=False)
            feat = upconv(torch.cat([feat, skip], dim=1))

        scores = self.conv_cls(feat)

        # NHWC score map matches the reference implementation's output.
        return scores.permute(0, 2, 3, 1), feat
|
|
|
|
| |
| |
| |
|
|
# Path to the pretrained CRAFT MLT 25K checkpoint; see the module
# docstring for the download command.
WEIGHTS_PATH = "craft_mlt_25k.pth"
# Destination for the exported ONNX graph.
OUTPUT_PATH = "model.onnx"
|
|
|
|
def load_model():
    """Build CRAFT, load the checkpoint, and return it in eval mode.

    The released checkpoint was saved from a ``DataParallel``-wrapped
    model, so every key carries a leading ``"module."`` that must be
    stripped before ``load_state_dict`` can match parameter names.

    Returns:
        The CRAFT model with loaded weights, in eval mode.
    """
    model = CRAFT()
    state_dict = torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True)

    # removeprefix strips only a *leading* "module." — unlike
    # str.replace, it cannot corrupt a key that happens to contain
    # "module." elsewhere in its name. (Plain dicts preserve insertion
    # order, so OrderedDict is unnecessary.)
    state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}

    model.load_state_dict(state_dict)
    model.eval()
    return model
|
|
|
|
def export_onnx(model):
    """Trace ``model`` with a dummy 640x640 RGB batch and write OUTPUT_PATH.

    Batch size and spatial dimensions are declared dynamic so the ONNX
    graph accepts arbitrary input resolutions at inference time.
    """
    example = torch.randn(1, 3, 640, 640)

    dynamic_axes = {
        "input": {0: "batch", 2: "height", 3: "width"},
        # score_map is channel-last (NHWC), hence axes 1/2 are spatial.
        "score_map": {0: "batch", 1: "height", 2: "width"},
        "feature_map": {0: "batch", 2: "height", 3: "width"},
    }
    torch.onnx.export(
        model,
        example,
        OUTPUT_PATH,
        opset_version=17,
        input_names=["input"],
        output_names=["score_map", "feature_map"],
        dynamic_axes=dynamic_axes,
    )
    print(f"Exported to {OUTPUT_PATH}")
|
|
|
|
def validate():
    """Smoke-test the exported graph with onnxruntime on random input.

    Imports are local so the export path works even when onnxruntime
    is not installed.
    """
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(OUTPUT_PATH)
    batch = np.random.randn(1, 3, 640, 640).astype(np.float32)
    score_map, feature_map = session.run(None, {"input": batch})
    print(f"Validation OK:")
    print(f"  score_map shape: {score_map.shape}")
    print(f"  feature_map shape: {feature_map.shape}")
|
|
|
|
if __name__ == "__main__":
    if not os.path.exists(WEIGHTS_PATH):
        print(f"Download weights first:")
        print(f"  gdown 1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ -O {WEIGHTS_PATH}")
        print(f"  (from https://github.com/clovaai/CRAFT-pytorch)")
        # `exit()` is a `site`-module convenience absent under
        # `python -S`; raising SystemExit is the portable way to set
        # a nonzero exit code.
        raise SystemExit(1)

    model = load_model()
    export_onnx(model)
    validate()
|
|