Valerii Sielikhov
Initial release: Ukrainian OCR/ICR model with HF custom model/processor and ONNX export
b5b608e | from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import torch | |
| from configuration_htr import HTRConfig | |
| from modeling_htr import HTRConvTextModel | |
class InferenceWrapper(torch.nn.Module):
    """Adapter exposing a single-tensor-in, single-tensor-out ``forward``.

    The wrapped model is invoked with keyword arguments and returns a
    structured output; this module narrows that down to just the ``logits``
    tensor so the exporter sees one positional input and one output.
    """

    def __init__(self, model: HTRConvTextModel):
        """Store the model whose ``logits`` output will be exposed."""
        super().__init__()
        self.model = model

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model on ``image`` and return only its logits.

        Args:
            image: Tensor forwarded to the model as ``pixel_values``.

        Returns:
            The ``logits`` attribute of the model's dict-style output.
        """
        outputs = self.model(pixel_values=image, return_dict=True)
        return outputs.logits
def main() -> None:
    """Command-line entry point: export an HF checkpoint to ONNX.

    Steps, in order:
      1. Parse ``--hf-model-dir``, ``--output-dir``, ``--onnx-name`` and
         ``--dummy-width``.
      2. Create the output directory if it does not exist.
      3. Load the model (float32) and its config from ``--hf-model-dir`` and
         switch the model to eval mode.
      4. Export the model, wrapped in ``InferenceWrapper``, to
         ``<output-dir>/<onnx-name>`` using a random dummy input.
      5. If ``alphabet.json`` exists in the model directory, re-serialize it
         into the output directory.
      6. Print the path of the exported ONNX file.
    """
    cli = argparse.ArgumentParser(description="Export HF model to ONNX.")
    cli.add_argument(
        "--hf-model-dir",
        required=True,
        help="Directory with HF artifacts.",
    )
    cli.add_argument(
        "--output-dir",
        default="onnx",
        help="ONNX output directory.",
    )
    cli.add_argument(
        "--onnx-name",
        default="model.onnx",
        help="Output ONNX model filename.",
    )
    cli.add_argument(
        "--dummy-width",
        type=int,
        default=3072,
        help="Dummy input width for export.",
    )
    opts = cli.parse_args()

    source_dir = opts.hf_model_dir
    target_dir = opts.output_dir
    os.makedirs(target_dir, exist_ok=True)
    destination = os.path.join(target_dir, opts.onnx_name)

    net = HTRConvTextModel.from_pretrained(
        source_dir,
        trust_remote_code=True,
        low_cpu_mem_usage=False,
        torch_dtype=torch.float32,
    )
    net.eval()

    config = HTRConfig.from_pretrained(source_dir, trust_remote_code=True)

    # Random sample used only to drive the export; height comes from the
    # config, width from the CLI. Presumably laid out as
    # (batch, channels, height, width) -- confirm against the model.
    sample = torch.randn(1, 1, config.image_height, opts.dummy_width)
    exportable = InferenceWrapper(net)

    # Axes marked dynamic so the exported graph is not tied to the dummy
    # sample's batch size or width.
    axis_names = {
        "image": {0: "batch", 3: "width"},
        "logits": {0: "batch", 1: "timesteps"},
    }
    torch.onnx.export(
        exportable,
        sample,
        destination,
        input_names=["image"],
        output_names=["logits"],
        dynamic_axes=axis_names,
        opset_version=18,
        do_constant_folding=True,
        export_params=True,
    )

    # Ship the alphabet alongside the ONNX file when the checkpoint has one.
    alphabet_src = os.path.join(source_dir, "alphabet.json")
    if os.path.isfile(alphabet_src):
        alphabet_dst = os.path.join(target_dir, "alphabet.json")
        with open(alphabet_src, "r", encoding="utf-8") as src:
            alphabet = json.load(src)
        with open(alphabet_dst, "w", encoding="utf-8") as dst:
            json.dump(alphabet, dst, ensure_ascii=False, indent=2)

    print(f"ONNX exported to: {destination}")
# Run the export only when this file is executed as a script, so importing
# the module (e.g. to reuse InferenceWrapper) has no side effects.
if __name__ == "__main__":
    main()