File size: 3,625 Bytes
8b09a83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import importlib
import json
import os
import sys

import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download
from PIL import Image


def resolve_model_dir(args):
    """Return the model directory to load from.

    Precedence: explicit --model-dir, then --repo-id (downloaded via
    huggingface_hub), else the parent directory of this script.
    """
    # An explicit local directory always wins.
    if args.model_dir:
        return os.path.abspath(args.model_dir)
    # Otherwise fetch (or reuse a cached copy of) the hub snapshot.
    if args.repo_id:
        return snapshot_download(repo_id=args.repo_id)
    # Fall back to the repository root relative to this file.
    here = os.path.dirname(__file__)
    return os.path.abspath(os.path.join(here, ".."))


def load_processors(model_dir):
    """Instantiate the tokenizer and image processor bundled with the model.

    Prepends *model_dir* to ``sys.path`` so the checkpoint's custom
    ``tokenization_glm`` and ``image_processing_m2_encoder`` modules are
    importable, then builds both processors from files in *model_dir*.
    Returns a ``(tokenizer, image_processor)`` pair.
    """
    sys.path.insert(0, model_dir)

    config_path = os.path.join(model_dir, "tokenizer_config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    tokenizer_cls = importlib.import_module("tokenization_glm").GLMChineseTokenizer
    processor_cls = importlib.import_module("image_processing_m2_encoder").M2EncoderImageProcessor

    # Special tokens come from the shipped tokenizer_config.json.
    special_tokens = {
        name: cfg.get(name)
        for name in ("eos_token", "pad_token", "cls_token", "mask_token", "unk_token")
    }
    tokenizer = tokenizer_cls(
        vocab_file=os.path.join(model_dir, "sp.model"),
        **special_tokens,
    )
    image_processor = processor_cls.from_pretrained(model_dir)
    return tokenizer, image_processor


def softmax(x):
    """Numerically stable softmax over the last axis of *x*."""
    # Subtracting the per-row max prevents exp() overflow; the shift
    # cancels out in the normalized ratio.
    shifted = np.exp(x - x.max(axis=-1, keepdims=True))
    return shifted / shifted.sum(axis=-1, keepdims=True)


def main():
    """CLI entry point: rank candidate text labels against one image.

    Loads the tokenizer/image processor from the resolved model directory,
    runs the exported ONNX text and image encoders on CPU, and prints a
    JSON ranking of the labels by softmax probability.
    """
    parser = argparse.ArgumentParser(description="Run M2-Encoder ONNX inference.")
    parser.add_argument("--repo-id", help="Hugging Face repo id to download.")
    parser.add_argument("--model-dir", help="Local model directory. Defaults to this repo root.")
    parser.add_argument("--image", required=True, help="Local image path.")
    parser.add_argument(
        "--text",
        nargs="+",
        required=True,
        help="Candidate text labels. Example: --text 杰尼龟 妙蛙种子 小火龙 皮卡丘",
    )
    args = parser.parse_args()

    model_dir = resolve_model_dir(args)
    tokenizer, image_processor = load_processors(model_dir)

    # Fixed-length padding to 52 tokens; presumably matches the sequence
    # length the text encoder was exported with — TODO confirm.
    encoded_text = tokenizer(
        args.text,
        padding="max_length",
        truncation=True,
        max_length=52,
        return_special_tokens_mask=True,
        return_tensors="np",
    )
    encoded_image = image_processor(
        Image.open(args.image).convert("RGB"),
        return_tensors="np",
    )

    onnx_dir = os.path.join(model_dir, "onnx")
    providers = ["CPUExecutionProvider"]
    text_session = ort.InferenceSession(
        os.path.join(onnx_dir, "text_encoder.onnx"),
        providers=providers,
    )
    image_session = ort.InferenceSession(
        os.path.join(onnx_dir, "image_encoder.onnx"),
        providers=providers,
    )

    text_feeds = {
        "input_ids": encoded_text["input_ids"],
        "attention_mask": encoded_text["attention_mask"],
    }
    text_embeds = text_session.run(None, text_feeds)[0]
    image_feeds = {"pixel_values": encoded_image["pixel_values"]}
    image_embeds = image_session.run(None, image_feeds)[0]

    # One image against N texts: a single similarity row, softmaxed over
    # the candidate labels.
    scores = image_embeds @ text_embeds.T
    probs = softmax(scores)
    triples = zip(args.text, scores[0].tolist(), probs[0].tolist())
    ranked = [
        {"label": label, "score": float(score), "prob": float(prob)}
        for label, score, prob in sorted(triples, key=lambda t: t[2], reverse=True)
    ]
    print(json.dumps({"ranked_results": ranked}, ensure_ascii=False, indent=2))


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()