AniFileBERT / tools /onnx_inference.py
ModerRAS's picture
Organize parser modules and tools
8c50d16
"""
Minimal ONNX Runtime inference example for AniFileBERT.
The ONNX file outputs token logits only. End-to-end parsing still needs the
repository tokenizer, constrained BIO decoding, and the same field aggregation
used by anifilebert.inference.
Usage:
python -m tools.onnx_inference "[GM-Team][国漫][神印王座][Throne of Seal][2022][200][AVC][GB][1080P].mp4"
"""
import argparse
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import onnxruntime as ort
import torch
from anifilebert.inference import constrained_bio_decode, postprocess
from anifilebert.tokenizer import AnimeTokenizer, load_tokenizer
def encode(
    filename: str,
    tokenizer: AnimeTokenizer,
    max_length: int,
) -> Tuple[List[str], np.ndarray, np.ndarray, int]:
    """Tokenize *filename* and pack it into fixed-length ONNX model inputs.

    The sequence layout is ``[CLS] tok_1 ... tok_k [SEP] [PAD] ...`` where
    ``k`` is the number of tokens that fit in ``max_length - 2`` positions;
    any further tokens are dropped (truncated).

    Args:
        filename: Raw filename string to tokenize.
        tokenizer: Tokenizer providing ``tokenize``, ``convert_tokens_to_ids``
            and ``cls_token_id`` / ``sep_token_id`` / ``pad_token_id``.
        max_length: Total sequence length of the returned arrays, including
            the two special tokens.

    Returns:
        ``(used_tokens, input_ids, attention_mask, available)``: the tokens
        kept after truncation, ``int64`` arrays of shape ``(1, max_length)``
        for the ids and the attention mask, and ``len(used_tokens)``.

    Raises:
        ValueError: If ``max_length`` is smaller than 2, i.e. there is no
            room for the ``[CLS]`` / ``[SEP]`` pair.
    """
    if max_length < 2:
        # A negative token budget would otherwise slice tokens[:-1] and emit
        # arrays longer than max_length with a negative token count.
        raise ValueError(
            f"max_length must be at least 2 to fit [CLS] and [SEP], got {max_length}"
        )
    tokens = tokenizer.tokenize(filename)
    available = min(len(tokens), max_length - 2)
    used_tokens = tokens[:available]

    input_ids = [
        tokenizer.cls_token_id,
        *tokenizer.convert_tokens_to_ids(used_tokens),
        tokenizer.sep_token_id,
    ]
    attention_mask = [1] * len(input_ids)
    # Always >= 0 thanks to the truncation above.
    pad_len = max_length - len(input_ids)
    input_ids += [tokenizer.pad_token_id] * pad_len
    attention_mask += [0] * pad_len
    return (
        used_tokens,
        np.asarray([input_ids], dtype=np.int64),
        np.asarray([attention_mask], dtype=np.int64),
        available,
    )
def load_id2label(model_dir: Path) -> Dict[int, str]:
    """Read the ``id2label`` mapping from ``<model_dir>/config.json``.

    JSON object keys are always strings, so each key is converted back to an
    ``int`` label id; the label names are returned unchanged.
    """
    with (model_dir / "config.json").open(encoding="utf-8") as config_file:
        raw_mapping = json.load(config_file)["id2label"]
    return dict(zip(map(int, raw_mapping.keys()), raw_mapping.values()))
def parse_with_onnx(
    filename: str,
    model_dir: Path,
    onnx_path: Path,
    max_length: int,
) -> Dict:
    """Parse a single filename with a freshly constructed ``OnnxFilenameParser``.

    Intended for one-off use: the tokenizer, label map and ONNX session are
    loaded on every call, so callers parsing many filenames should build an
    ``OnnxFilenameParser`` once and call its ``parse`` method repeatedly.
    """
    return OnnxFilenameParser(
        model_dir=model_dir,
        onnx_path=onnx_path,
        max_length=max_length,
    ).parse(filename)
class OnnxFilenameParser:
    """Reusable ONNX Runtime parser with tokenizer and session loaded once.

    The tokenizer, ``id2label`` map and ``ort.InferenceSession`` are created in
    ``__init__``; :meth:`parse` can then be called repeatedly without reloading.
    """

    def __init__(
        self,
        model_dir: Path,
        onnx_path: Path,
        max_length: int,
        providers: Optional[List[str]] = None,
        session_options: Optional[ort.SessionOptions] = None,
    ) -> None:
        """Load the tokenizer, label map and ONNX Runtime session.

        Args:
            model_dir: Directory holding the tokenizer files and ``config.json``
                (which must contain an ``id2label`` mapping). ``str`` is accepted.
            onnx_path: Path to the exported ONNX model. ``str`` is accepted.
            max_length: Sequence length fed to the model. Presumably this must
                match the static length used at export time — confirm.
            providers: ONNX Runtime execution providers; falls back to
                ``["CPUExecutionProvider"]`` when ``None`` or empty.
            session_options: Optional ``ort.SessionOptions`` for the session.
        """
        # Normalise to Path so plain-string arguments also work with the
        # ``model_dir / "config.json"`` join used when loading the label map.
        self.model_dir = Path(model_dir)
        self.onnx_path = Path(onnx_path)
        self.max_length = max_length
        self.tokenizer = load_tokenizer(str(self.model_dir))
        self.id2label = load_id2label(self.model_dir)
        self.session = ort.InferenceSession(
            str(self.onnx_path),
            sess_options=session_options,
            providers=providers or ["CPUExecutionProvider"],
        )

    def parse(self, filename: str) -> Dict:
        """Parse one filename into the field dict produced by ``postprocess``.

        The original input string is echoed back under the ``"_input"`` key.
        """
        tokens, input_ids, attention_mask, available = encode(
            filename, self.tokenizer, self.max_length
        )
        logits = self.session.run(
            ["logits"],
            {"input_ids": input_ids, "attention_mask": attention_mask},
        )[0]
        if available > 0:
            # Skip the [CLS] row and stop before [SEP]/padding so exactly one
            # logit row remains per kept token.
            token_logits = torch.from_numpy(logits[0, 1:1 + available, :])
            label_ids = constrained_bio_decode(token_logits, self.id2label)
        else:
            # No tokens (e.g. empty filename): do not hand an empty tensor to
            # the BIO decoder.
            label_ids = []
        labels = [self.id2label.get(label_id, "O") for label_id in label_ids]
        result = postprocess(tokens, labels, tokenizer=self.tokenizer)
        result["_input"] = filename
        return result
def main() -> None:
    """Command-line entry point: parse one filename and print the result as JSON.

    Invalid ``--max-length`` values and a missing ONNX file are reported as
    argparse usage errors (exit status 2) instead of tracebacks.
    """
    parser = argparse.ArgumentParser(description="Run AniFileBERT ONNX inference")
    parser.add_argument("filename", help="Anime filename to parse")
    parser.add_argument("--model-dir", default=".", help="Directory containing vocab.json and config.json")
    parser.add_argument("--onnx", default="exports/anime_filename_parser.onnx", help="ONNX model path")
    parser.add_argument("--max-length", type=int, default=128, help="Static ONNX sequence length")
    args = parser.parse_args()

    # The encoded sequence always reserves two positions for [CLS] and [SEP].
    if args.max_length < 2:
        parser.error(f"--max-length must be at least 2 (room for [CLS] and [SEP]), got {args.max_length}")
    onnx_path = Path(args.onnx)
    if not onnx_path.is_file():
        parser.error(f"ONNX model not found: {onnx_path}")

    result = parse_with_onnx(
        filename=args.filename,
        model_dir=Path(args.model_dir),
        onnx_path=onnx_path,
        max_length=args.max_length,
    )
    print(json.dumps(result, ensure_ascii=False))


if __name__ == "__main__":
    main()