# --- Hugging Face repository page residue (kept for provenance) ---
# Tags: ONNX · text-detection · craft · inference4j
# Repo file: craft-mlt-25k / craft_exporter.py
# Uploaded by vccarvalho11 — "Upload CRAFT MLT 25K ONNX model" (commit 99781a5, verified)
"""
CRAFT (Character Region Awareness for Text Detection) — ONNX Export Script
Exports the CRAFT MLT 25K model from PyTorch to ONNX format.
Usage:
1. Download weights from https://drive.google.com/uc?id=1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ
(or use gdown: gdown 1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ -O craft_mlt_25k.pth)
2. pip install torch torchvision onnx onnxruntime
3. python craft_exporter.py
Original weights: clovaai/CRAFT-pytorch (https://github.com/clovaai/CRAFT-pytorch)
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torchvision import models
# ============================================================
# Model definitions matching clovaai/CRAFT-pytorch exactly
# ============================================================
def init_weights(modules):
    """Apply CRAFT's layer initialization in place.

    Conv2d layers get Xavier-uniform weights; BatchNorm2d layers get unit
    scale. Biases of both layer kinds are zeroed. Other module types are
    left untouched.
    """
    for layer in modules:
        if isinstance(layer, nn.BatchNorm2d):
            layer.weight.data.fill_(1)
            layer.bias.data.zero_()
        elif isinstance(layer, nn.Conv2d):
            nn.init.xavier_uniform_(layer.weight.data)
            if layer.bias is not None:
                layer.bias.data.zero_()
class VGG16BN(nn.Module):
    """VGG16-BN backbone sliced into the five stages CRAFT consumes.

    Submodules inside each slice keep their original torchvision feature
    indices so the published craft_mlt_25k state_dict keys line up exactly.
    """

    def __init__(self, pretrained=False):
        """
        Args:
            pretrained: if True, start the backbone from ImageNet VGG16-BN
                weights. This flag was previously accepted but silently
                ignored; it is now honored. It is irrelevant when a full
                CRAFT checkpoint is loaded over the model afterwards.
        """
        super().__init__()
        # torchvision >= 0.13 deprecated `pretrained=` (removed in 0.15);
        # use the `weights=` enum API instead.
        weights = models.VGG16_BN_Weights.IMAGENET1K_V1 if pretrained else None
        vgg_features = models.vgg16_bn(weights=weights).features
        self.slice1 = nn.Sequential()
        self.slice2 = nn.Sequential()
        self.slice3 = nn.Sequential()
        self.slice4 = nn.Sequential()
        # Use add_module with original indices to match state_dict keys.
        for idx in range(12):        # up through conv2_2
            self.slice1.add_module(str(idx), vgg_features[idx])
        for idx in range(12, 19):    # up through conv3_3
            self.slice2.add_module(str(idx), vgg_features[idx])
        for idx in range(19, 29):    # up through conv4_3
            self.slice3.add_module(str(idx), vgg_features[idx])
        for idx in range(29, 39):    # up through conv5_3
            self.slice4.add_module(str(idx), vgg_features[idx])
        # fc6/fc7 replacement head (no atrous VGG classifier); built directly
        # here — the previous empty `nn.Sequential()` placeholder was dead code.
        self.slice5 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
            nn.Conv2d(1024, 1024, kernel_size=1),
        )
        # slice5 has no pretrained counterpart, so initialize it explicitly.
        init_weights(self.slice5.modules())

    def forward(self, x):
        """Return the tuple (fc7, relu5_3, relu4_3, relu3_2, relu2_2)."""
        h = self.slice1(x)
        h_relu2_2 = h
        h = self.slice2(h)
        h_relu3_2 = h
        h = self.slice3(h)
        h_relu4_3 = h
        h = self.slice4(h)
        h_relu5_3 = h
        h_fc7 = self.slice5(h)
        return h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2
class DoubleConv(nn.Module):
    """Decoder block used by CRAFT's U-network.

    A 1x1 channel-reducing convolution followed by a 3x3 convolution, each
    with batch-norm and ReLU. The first convolution takes `in_ch + mid_ch`
    input channels because the decoder concatenates an upsampled map with a
    skip connection before calling this block.
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        stages = [
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
class CRAFT(nn.Module):
    """CRAFT text detector: VGG16-BN backbone plus a U-shaped decoder.

    forward() returns an NHWC two-channel score map (region/affinity) at
    half the input resolution, together with the NCHW 32-channel feature
    map it was computed from.
    """

    def __init__(self):
        super().__init__()
        self.basenet = VGG16BN()
        # Decoder ("U" network) blocks, deepest-to-shallowest skip order.
        self.upconv1 = DoubleConv(1024, 512, 256)
        self.upconv2 = DoubleConv(512, 256, 128)
        self.upconv3 = DoubleConv(256, 128, 64)
        self.upconv4 = DoubleConv(128, 64, 32)
        num_class = 2  # region score + affinity score
        self.conv_cls = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, num_class, kernel_size=1),
        )
        for head in (self.upconv1, self.upconv2, self.upconv3,
                     self.upconv4, self.conv_cls):
            init_weights(head.modules())

    def forward(self, x):
        # Backbone stage outputs, deepest first.
        fc7, relu5_3, relu4_3, relu3_2, relu2_2 = self.basenet(x)

        out = self.upconv1(torch.cat([fc7, relu5_3], dim=1))
        # Walk back up: upsample to the skip's spatial size, concat, refine.
        for skip, upconv in ((relu4_3, self.upconv2),
                             (relu3_2, self.upconv3),
                             (relu2_2, self.upconv4)):
            out = F.interpolate(out, size=skip.size()[2:],
                                mode='bilinear', align_corners=False)
            out = torch.cat([out, skip], dim=1)
            out = upconv(out)
        feature = out

        score = self.conv_cls(feature)
        # NHWC score map, NCHW shared feature map.
        return score.permute(0, 2, 3, 1), feature
# ============================================================
# Export and validate
# ============================================================
# Input checkpoint (PyTorch, expected next to this script) and ONNX output.
WEIGHTS_PATH = "craft_mlt_25k.pth"
OUTPUT_PATH = "model.onnx"
def load_model():
    """Build a CRAFT model, load WEIGHTS_PATH into it, and return it in eval mode.

    Checkpoints saved from a ``nn.DataParallel`` wrapper prefix every key
    with ``module.``; that prefix is stripped before loading.
    """
    model = CRAFT()
    state_dict = torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True)
    # Strip only a *leading* 'module.' prefix. The previous
    # k.replace("module.", "") would also mangle any key that happened to
    # contain 'module.' mid-string.
    prefix = "module."
    cleaned = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    model.load_state_dict(cleaned)
    model.eval()
    return model
def export_onnx(model):
    """Export `model` to ONNX at OUTPUT_PATH with dynamic batch/height/width.

    Axis note: `score_map` is NHWC (the forward pass permutes it), so its
    spatial axes are 1 and 2; `feature_map` stays NCHW with spatial axes
    2 and 3.
    """
    example = torch.randn(1, 3, 640, 640)
    dynamic_axes = {
        "input": {0: "batch", 2: "height", 3: "width"},
        "score_map": {0: "batch", 1: "height", 2: "width"},
        "feature_map": {0: "batch", 2: "height", 3: "width"},
    }
    torch.onnx.export(
        model,
        example,
        OUTPUT_PATH,
        opset_version=17,
        input_names=["input"],
        output_names=["score_map", "feature_map"],
        dynamic_axes=dynamic_axes,
    )
    print(f"Exported to {OUTPUT_PATH}")
def validate():
    """Smoke-test the exported ONNX model by running a random 640x640 input."""
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(OUTPUT_PATH)
    dummy = np.random.randn(1, 3, 640, 640).astype(np.float32)
    score_map, feature_map = session.run(None, {"input": dummy})
    print(f"Validation OK:")
    print(f" score_map shape: {score_map.shape}")  # (1, 320, 320, 2)
    print(f" feature_map shape: {feature_map.shape}")  # (1, 32, 320, 320)
if __name__ == "__main__":
    if not os.path.exists(WEIGHTS_PATH):
        print(f"Download weights first:")
        print(f" gdown 1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ -O {WEIGHTS_PATH}")
        print(f" (from https://github.com/clovaai/CRAFT-pytorch)")
        # `exit()` is a site-module convenience and may be absent (python -S);
        # raising SystemExit is the reliable way to set the exit code.
        raise SystemExit(1)
    model = load_model()
    export_onnx(model)
    validate()