File size: 4,113 Bytes
376db19
 
 
 
 
8c50d16
376db19
 
8c50d16
376db19
 
 
 
 
ce3a60d
376db19
 
 
 
 
8c50d16
 
376db19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce3a60d
116c87c
ce3a60d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116c87c
ce3a60d
 
 
 
 
 
 
 
 
 
 
 
116c87c
ce3a60d
 
376db19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c50d16
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
Minimal ONNX Runtime inference example for AniFileBERT.

The ONNX file outputs token logits only. End-to-end parsing still needs the
repository tokenizer, constrained BIO decoding, and the same field aggregation
used by anifilebert.inference.

Usage:
    python -m tools.onnx_inference "[GM-Team][国漫][神印王座][Throne of Seal][2022][200][AVC][GB][1080P].mp4"
"""

import argparse
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import onnxruntime as ort
import torch

from anifilebert.inference import constrained_bio_decode, postprocess
from anifilebert.tokenizer import AnimeTokenizer, load_tokenizer


def encode(
    filename: str,
    tokenizer: AnimeTokenizer,
    max_length: int,
) -> Tuple[List[str], np.ndarray, np.ndarray, int]:
    """Tokenize *filename* and build fixed-length model inputs.

    The id sequence is laid out as ``[CLS] tokens... [SEP]`` and then padded
    with the tokenizer's pad id up to ``max_length``. Tokens that do not fit
    between the two special tokens are dropped from the end.

    Args:
        filename: Raw filename string to tokenize.
        tokenizer: Object providing ``tokenize``, ``convert_tokens_to_ids`` and
            the ``cls_token_id`` / ``sep_token_id`` / ``pad_token_id``
            attributes.
        max_length: Total sequence length including the two special tokens.

    Returns:
        A tuple ``(used_tokens, input_ids, attention_mask, available)`` where
        ``used_tokens`` is the (possibly truncated) token list, ``input_ids``
        and ``attention_mask`` are ``int64`` arrays of shape
        ``(1, max_length)``, and ``available`` is ``len(used_tokens)``.

    Raises:
        ValueError: If ``max_length`` is smaller than 2, which leaves no room
            for the CLS and SEP tokens.
    """
    # A negative token budget would turn the slice below into "drop from the
    # end" and produce a sequence longer than max_length, so reject it early.
    if max_length < 2:
        raise ValueError(
            f"max_length must be at least 2 to fit the CLS and SEP tokens, got {max_length}"
        )

    tokens = tokenizer.tokenize(filename)
    used_tokens = tokens[: max_length - 2]
    available = len(used_tokens)

    input_ids = [
        tokenizer.cls_token_id,
        *tokenizer.convert_tokens_to_ids(used_tokens),
        tokenizer.sep_token_id,
    ]
    attention_mask = [1] * len(input_ids)

    # pad_len is never negative here; multiplying a list by 0 adds nothing.
    pad_len = max_length - len(input_ids)
    input_ids.extend([tokenizer.pad_token_id] * pad_len)
    attention_mask.extend([0] * pad_len)

    return (
        used_tokens,
        np.asarray([input_ids], dtype=np.int64),
        np.asarray([attention_mask], dtype=np.int64),
        available,
    )


def load_id2label(model_dir: Path) -> Dict[int, str]:
    """Read the ``id2label`` table from ``config.json`` inside *model_dir*.

    JSON object keys are strings, so each key is converted to ``int`` before
    being used as a label id.
    """
    config_path = model_dir / "config.json"
    with config_path.open(encoding="utf-8") as handle:
        config = json.load(handle)

    id2label: Dict[int, str] = {}
    for raw_id, name in config["id2label"].items():
        id2label[int(raw_id)] = name
    return id2label


def parse_with_onnx(
    filename: str,
    model_dir: Path,
    onnx_path: Path,
    max_length: int,
) -> Dict:
    """Parse a single *filename* with a freshly constructed ONNX parser.

    A new ``OnnxFilenameParser`` is built on every call; for repeated parsing,
    construct the parser once and call its ``parse`` method directly.
    """
    return OnnxFilenameParser(
        model_dir=model_dir,
        onnx_path=onnx_path,
        max_length=max_length,
    ).parse(filename)


class OnnxFilenameParser:
    """Reusable ONNX Runtime parser with tokenizer and session loaded once.

    The constructor loads the tokenizer, the ``id2label`` table and the ONNX
    Runtime session; ``parse`` can then be called repeatedly without paying
    that setup cost again.
    """

    def __init__(
        self,
        model_dir: Path,
        onnx_path: Path,
        max_length: int,
        providers: Optional[List[str]] = None,
        session_options: Optional[ort.SessionOptions] = None,
    ) -> None:
        """Load the tokenizer, label table and inference session.

        Args:
            model_dir: Directory passed to ``load_tokenizer`` and searched for
                ``config.json`` by ``load_id2label``.
            onnx_path: Path of the ONNX model file to load.
            max_length: Sequence length used when encoding filenames.
            providers: ONNX Runtime execution providers; when omitted or
                empty, ``["CPUExecutionProvider"]`` is used.
            session_options: Optional ONNX Runtime session options forwarded
                to ``ort.InferenceSession``.
        """
        self.model_dir = model_dir
        self.onnx_path = onnx_path
        self.max_length = max_length
        self.tokenizer = load_tokenizer(str(model_dir))
        self.id2label = load_id2label(model_dir)
        self.session = ort.InferenceSession(
            str(onnx_path),
            sess_options=session_options,
            providers=providers or ["CPUExecutionProvider"],
        )

    def parse(self, filename: str) -> Dict:
        """Run the model on *filename* and return the post-processed result.

        Args:
            filename: Filename string to parse.

        Returns:
            The value returned by ``postprocess`` with an extra ``"_input"``
            entry holding the original filename.
        """
        tokens, input_ids, attention_mask, available = encode(filename, self.tokenizer, self.max_length)
        logits = self.session.run(
            ["logits"],
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
            },
        )[0]

        # Position 0 is the CLS token; keep only the `available` positions
        # that correspond to real filename tokens (no SEP, no padding).
        token_logits = torch.from_numpy(logits[0, 1:1 + available, :])
        label_ids = constrained_bio_decode(token_logits, self.id2label)
        # Ids missing from the label table fall back to the outside tag "O".
        labels = [self.id2label.get(label_id, "O") for label_id in label_ids]
        result = postprocess(tokens, labels, tokenizer=self.tokenizer)
        result["_input"] = filename
        return result


def main() -> None:
    """Command-line entry point: parse one filename and print JSON to stdout."""
    cli = argparse.ArgumentParser(description="Run AniFileBERT ONNX inference")
    cli.add_argument("filename", help="Anime filename to parse")
    cli.add_argument("--model-dir", default=".", help="Directory containing vocab.json and config.json")
    cli.add_argument("--onnx", default="exports/anime_filename_parser.onnx", help="ONNX model path")
    cli.add_argument("--max-length", type=int, default=128, help="Static ONNX sequence length")
    options = cli.parse_args()

    model_dir = Path(options.model_dir)
    onnx_path = Path(options.onnx)
    parsed = parse_with_onnx(
        options.filename,
        model_dir,
        onnx_path,
        options.max_length,
    )

    # ensure_ascii=False keeps non-ASCII characters readable in the output.
    print(json.dumps(parsed, ensure_ascii=False))


# Run the command-line interface only when this module is executed as a
# script (for example via ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()