# ukr-htr-convtext / export_onnx.py
# Valerii Sielikhov
# Initial release: Ukrainian OCR/ICR model with HF custom model/processor and ONNX export
# b5b608e
from __future__ import annotations

import argparse
import json
import os
import sys

import torch

from configuration_htr import HTRConfig
from modeling_htr import HTRConvTextModel
class InferenceWrapper(torch.nn.Module):
    """Adapter that gives the HTR model a single-tensor ``forward`` for export.

    The wrapped model is invoked with ``pixel_values`` and ``return_dict=True``.
    Only the ``logits`` attribute of its output is returned, so tracing sees
    exactly one tensor going in and one tensor coming out.
    """

    def __init__(self, model: HTRConvTextModel):
        """Store *model* as the ``model`` submodule (it is registered on this Module)."""
        super().__init__()
        self.model = model

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model on *image* and return only its logits tensor."""
        # return_dict=True requests a structured output whose logits are read
        # by attribute name below.
        outputs = self.model(pixel_values=image, return_dict=True)
        logits = outputs.logits
        return logits
def _parse_args() -> argparse.Namespace:
    """Build the CLI parser, parse ``sys.argv`` and validate the values."""
    parser = argparse.ArgumentParser(description="Export HF model to ONNX.")
    parser.add_argument(
        "--hf-model-dir", required=True, help="Directory with HF artifacts."
    )
    parser.add_argument("--output-dir", default="onnx", help="ONNX output directory.")
    parser.add_argument(
        "--onnx-name", default="model.onnx", help="Output ONNX model filename."
    )
    parser.add_argument(
        "--dummy-width", type=int, default=3072, help="Dummy input width for export."
    )
    args = parser.parse_args()
    # A non-positive width cannot form a valid dummy image. Fail fast with a
    # usage error (exit code 2) instead of failing later in tensor creation
    # or inside the ONNX export.
    if args.dummy_width <= 0:
        parser.error(
            f"--dummy-width must be a positive integer, got {args.dummy_width}"
        )
    return args


def _export_onnx(model_dir: str, onnx_path: str, dummy_width: int) -> None:
    """Load the HF model from *model_dir* in float32 and export it to *onnx_path*."""
    model = HTRConvTextModel.from_pretrained(
        model_dir,
        trust_remote_code=True,
        low_cpu_mem_usage=False,
        torch_dtype=torch.float32,
    )
    model.eval()
    cfg = HTRConfig.from_pretrained(model_dir, trust_remote_code=True)
    # The dummy is only an example input for export. Only batch (axis 0) and
    # width (axis 3) are declared dynamic below, so the channel count (1) and
    # the height (cfg.image_height) are baked into the exported graph.
    dummy = torch.randn(1, 1, cfg.image_height, dummy_width)
    torch.onnx.export(
        InferenceWrapper(model),
        dummy,
        onnx_path,
        input_names=["image"],
        output_names=["logits"],
        dynamic_axes={
            "image": {0: "batch", 3: "width"},
            "logits": {0: "batch", 1: "timesteps"},
        },
        opset_version=18,
        do_constant_folding=True,
        export_params=True,
    )


def _copy_alphabet(model_dir: str, output_dir: str) -> None:
    """Re-serialize ``alphabet.json`` from *model_dir* into *output_dir*.

    The JSON is written as UTF-8 with ``ensure_ascii=False`` and 2-space
    indentation. A missing source file is not an error, because the ONNX
    export itself has already succeeded. A warning is printed to stderr,
    since the alphabet is presumably needed to decode the exported logits.
    """
    src_path = os.path.join(model_dir, "alphabet.json")
    if not os.path.isfile(src_path):
        print(
            f"Warning: {src_path} not found; alphabet.json was not copied.",
            file=sys.stderr,
        )
        return
    with open(src_path, "r", encoding="utf-8") as f:
        alphabet = json.load(f)
    dst_path = os.path.join(output_dir, "alphabet.json")
    with open(dst_path, "w", encoding="utf-8") as f:
        json.dump(alphabet, f, ensure_ascii=False, indent=2)


def main() -> None:
    """CLI entry point: export the HF HTR model to ONNX and copy its alphabet.

    Writes ``<output-dir>/<onnx-name>`` and, when present in the model
    directory, ``<output-dir>/alphabet.json``. Exits with status 2 on invalid
    command-line arguments.
    """
    args = _parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    onnx_path = os.path.join(args.output_dir, args.onnx_name)
    _export_onnx(args.hf_model_dir, onnx_path, args.dummy_width)
    _copy_alphabet(args.hf_model_dir, args.output_dir)
    print(f"ONNX exported to: {onnx_path}")
# Run the exporter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()