"""Run M2-Encoder ONNX inference: rank candidate text labels against a local image."""

import argparse
import importlib
import json
import os
import sys

import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download
from PIL import Image


def resolve_model_dir(args):
    """Resolve the model directory: explicit path, Hub download, or this repo's root."""
    if args.model_dir:
        return os.path.abspath(args.model_dir)
    if args.repo_id:
        return snapshot_download(repo_id=args.repo_id)
    return os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))


def load_processors(model_dir):
    """Load the GLM tokenizer and M2-Encoder image processor shipped with the model."""
    # The tokenizer and image-processor modules live alongside the model files,
    # so make the model directory importable before loading them dynamically.
    sys.path.insert(0, model_dir)

    tokenizer_config_path = os.path.join(model_dir, "tokenizer_config.json")
    with open(tokenizer_config_path, "r", encoding="utf-8") as f:
        tokenizer_config = json.load(f)

    GLMChineseTokenizer = importlib.import_module("tokenization_glm").GLMChineseTokenizer
    M2EncoderImageProcessor = importlib.import_module("image_processing_m2_encoder").M2EncoderImageProcessor

    tokenizer = GLMChineseTokenizer(
        vocab_file=os.path.join(model_dir, "sp.model"),
        eos_token=tokenizer_config.get("eos_token"),
        pad_token=tokenizer_config.get("pad_token"),
        cls_token=tokenizer_config.get("cls_token"),
        mask_token=tokenizer_config.get("mask_token"),
        unk_token=tokenizer_config.get("unk_token"),
    )
    image_processor = M2EncoderImageProcessor.from_pretrained(model_dir)
    return tokenizer, image_processor


def softmax(x):
    """Numerically stable softmax over the last axis."""
    x = x - np.max(x, axis=-1, keepdims=True)
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)


def main():
    parser = argparse.ArgumentParser(description="Run M2-Encoder ONNX inference.")
    parser.add_argument("--repo-id", help="Hugging Face repo id to download.")
    parser.add_argument("--model-dir", help="Local model directory. Defaults to this repo root.")
    parser.add_argument("--image", required=True, help="Local image path.")
    parser.add_argument(
        "--text",
        nargs="+",
        required=True,
        help="Candidate text labels. Example: --text 杰尼龟 妙蛙种子 小火龙 皮卡丘",
    )
    args = parser.parse_args()

    model_dir = resolve_model_dir(args)
    tokenizer, image_processor = load_processors(model_dir)

    # Tokenize the candidate labels and preprocess the image into model-ready arrays.
    text_inputs = tokenizer(
        args.text,
        padding="max_length",
        truncation=True,
        max_length=52,
        return_special_tokens_mask=True,
        return_tensors="np",
    )
    image_inputs = image_processor(
        Image.open(args.image).convert("RGB"),
        return_tensors="np",
    )

    # Run the exported text and image encoders on CPU.
    text_session = ort.InferenceSession(
        os.path.join(model_dir, "onnx", "text_encoder.onnx"),
        providers=["CPUExecutionProvider"],
    )
    image_session = ort.InferenceSession(
        os.path.join(model_dir, "onnx", "image_encoder.onnx"),
        providers=["CPUExecutionProvider"],
    )

    text_embeds = text_session.run(
        None,
        {
            "input_ids": text_inputs["input_ids"],
            "attention_mask": text_inputs["attention_mask"],
        },
    )[0]
    image_embeds = image_session.run(
        None,
        {"pixel_values": image_inputs["pixel_values"]},
    )[0]

    # Image-text similarity scores, normalized into probabilities over the labels.
    scores = image_embeds @ text_embeds.T
    probs = softmax(scores)

    # Rank labels by probability (descending) and emit JSON.
    ranked = [
        {
            "label": label,
            "score": float(score),
            "prob": float(prob),
        }
        for label, score, prob in sorted(
            zip(args.text, scores[0].tolist(), probs[0].tolist()),
            key=lambda item: item[2],
            reverse=True,
        )
    ]
    print(json.dumps({"ranked_results": ranked}, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()