"""
Minimal ONNX Runtime inference example for AniFileBERT.

The ONNX file outputs token logits only. End-to-end parsing still needs the
repository tokenizer, constrained BIO decoding, and the same field aggregation
used by ``anifilebert.inference``.

Usage:
    python -m tools.onnx_inference "[GM-Team][国漫][神印王座][Throne of Seal][2022][200][AVC][GB][1080P].mp4"
"""

import argparse
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import onnxruntime as ort
import torch

from anifilebert.inference import constrained_bio_decode, postprocess
from anifilebert.tokenizer import AnimeTokenizer, load_tokenizer

# Every encoded sequence is wrapped as [CLS] tokens... [SEP].
_NUM_SPECIAL_TOKENS = 2


def _check_max_length(max_length: int) -> None:
    """Raise ``ValueError`` if *max_length* cannot hold the CLS and SEP tokens."""
    if max_length < _NUM_SPECIAL_TOKENS:
        raise ValueError(
            f"max_length must be at least {_NUM_SPECIAL_TOKENS} to fit the "
            f"CLS and SEP tokens, got {max_length}"
        )


def encode(
    filename: str,
    tokenizer: AnimeTokenizer,
    max_length: int,
) -> Tuple[List[str], np.ndarray, np.ndarray, int]:
    """Tokenize *filename* into fixed-length ONNX model inputs.

    The token sequence is truncated so that it fits between ``[CLS]`` and
    ``[SEP]`` inside *max_length* positions, then right-padded with the
    tokenizer's pad id.

    Args:
        filename: Raw filename to tokenize.
        tokenizer: Repository tokenizer providing ``tokenize``,
            ``convert_tokens_to_ids`` and the CLS/SEP/PAD token ids.
        max_length: Total sequence length fed to the ONNX graph, including
            the two special tokens.

    Returns:
        ``(used_tokens, input_ids, attention_mask, available)`` where
        ``used_tokens`` are the (possibly truncated) tokens, ``input_ids`` and
        ``attention_mask`` are int64 arrays with a leading batch axis of 1
        (mask is 1 for CLS/tokens/SEP and 0 for padding), and ``available``
        is ``len(used_tokens)``. Assuming ``convert_tokens_to_ids`` yields one
        id per token, both arrays have shape ``(1, max_length)``.

    Raises:
        ValueError: If *max_length* is smaller than 2.
    """
    # Without this guard ``max_length - 2`` goes negative and
    # ``tokens[:available]`` slices from the end instead of truncating,
    # producing an input longer than max_length.
    _check_max_length(max_length)

    tokens = tokenizer.tokenize(filename)
    available = min(len(tokens), max_length - _NUM_SPECIAL_TOKENS)
    used_tokens = tokens[:available]

    input_ids = [tokenizer.cls_token_id]
    input_ids.extend(tokenizer.convert_tokens_to_ids(used_tokens))
    input_ids.append(tokenizer.sep_token_id)
    attention_mask = [1] * len(input_ids)

    pad_len = max_length - len(input_ids)
    if pad_len > 0:
        input_ids.extend([tokenizer.pad_token_id] * pad_len)
        attention_mask.extend([0] * pad_len)

    return (
        used_tokens,
        np.asarray([input_ids], dtype=np.int64),
        np.asarray([attention_mask], dtype=np.int64),
        available,
    )


def load_id2label(model_dir: Path) -> Dict[int, str]:
    """Load the label-id -> label-name mapping from ``model_dir/config.json``.

    JSON object keys are always strings, so they are converted back to ints.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist.
        KeyError: If the config has no ``id2label`` entry.
    """
    config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
    return {int(label_id): label for label_id, label in config["id2label"].items()}


def parse_with_onnx(
    filename: str,
    model_dir: Path,
    onnx_path: Path,
    max_length: int,
) -> Dict:
    """Parse a single filename with a freshly constructed ONNX parser.

    This reloads the tokenizer, label map and ONNX session on every call;
    for many filenames build one :class:`OnnxFilenameParser` and reuse it.
    """
    parser = OnnxFilenameParser(model_dir, onnx_path, max_length)
    return parser.parse(filename)


class OnnxFilenameParser:
    """Reusable ONNX Runtime parser with tokenizer and session loaded once."""

    def __init__(
        self,
        model_dir: Path,
        onnx_path: Path,
        max_length: int,
        providers: Optional[List[str]] = None,
        session_options: Optional[ort.SessionOptions] = None,
    ) -> None:
        """Load the tokenizer, label map and ONNX Runtime session.

        Args:
            model_dir: Directory containing the tokenizer files and
                ``config.json``.
            onnx_path: Path of the exported ONNX model.
            max_length: Static sequence length used by :func:`encode`.
            providers: ONNX Runtime execution providers; when None or empty,
                ``["CPUExecutionProvider"]`` is used.
            session_options: Optional options forwarded to
                ``ort.InferenceSession``.

        Raises:
            ValueError: If *max_length* is smaller than 2 (checked before
                the comparatively expensive session load).
        """
        _check_max_length(max_length)
        self.model_dir = model_dir
        self.onnx_path = onnx_path
        self.max_length = max_length
        self.tokenizer = load_tokenizer(str(model_dir))
        self.id2label = load_id2label(model_dir)
        self.session = ort.InferenceSession(
            str(onnx_path),
            sess_options=session_options,
            providers=providers or ["CPUExecutionProvider"],
        )

    def parse(self, filename: str) -> Dict:
        """Parse *filename* into structured fields.

        Returns:
            The dict produced by ``anifilebert.inference.postprocess`` with
            the original filename added under the ``"_input"`` key.
        """
        tokens, input_ids, attention_mask, available = encode(
            filename, self.tokenizer, self.max_length
        )
        # NOTE(review): assumes the exported graph has inputs named
        # "input_ids"/"attention_mask" and a rank-3 output named "logits",
        # presumably (batch, seq, num_labels) — confirm against the exporter.
        logits = self.session.run(
            ["logits"],
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
            },
        )[0]
        # Position 0 holds [CLS]; keep only the rows for the real tokens so
        # that row i of token_logits lines up with tokens[i].
        token_logits = torch.from_numpy(logits[0, 1:1 + available, :])
        label_ids = constrained_bio_decode(token_logits, self.id2label)
        # Ids missing from the label map fall back to the outside tag "O".
        labels = [self.id2label.get(label_id, "O") for label_id in label_ids]
        result = postprocess(tokens, labels, tokenizer=self.tokenizer)
        result["_input"] = filename
        return result


def main() -> None:
    """Command-line entry point: parse one filename and print JSON to stdout."""
    parser = argparse.ArgumentParser(description="Run AniFileBERT ONNX inference")
    parser.add_argument("filename", help="Anime filename to parse")
    parser.add_argument("--model-dir", default=".", help="Directory containing vocab.json and config.json")
    parser.add_argument("--onnx", default="exports/anime_filename_parser.onnx", help="ONNX model path")
    parser.add_argument("--max-length", type=int, default=128, help="Static ONNX sequence length")
    args = parser.parse_args()

    if args.max_length < _NUM_SPECIAL_TOKENS:
        # Report as a normal CLI usage error instead of a traceback.
        parser.error(f"--max-length must be at least {_NUM_SPECIAL_TOKENS}")

    result = parse_with_onnx(
        filename=args.filename,
        model_dir=Path(args.model_dir),
        onnx_path=Path(args.onnx),
        max_length=args.max_length,
    )
    # ensure_ascii=False keeps non-ASCII titles readable instead of \u escapes.
    print(json.dumps(result, ensure_ascii=False))


if __name__ == "__main__":
    main()