File size: 3,791 Bytes
39a7193 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import numpy as np
from PIL import Image
import axengine as ort
import torch
from torchvision.transforms import Normalize, Compose, InterpolationMode, ToTensor, Resize, CenterCrop
from tokenizer import SimpleTokenizer
import argparse
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
def image_transform_v2():
    """Build the v2 image preprocessing pipeline.

    256x256 bicubic resize + center crop, converted to a tensor and
    normalized with the OpenAI CLIP dataset mean/std.

    Returns:
        A torchvision ``Compose`` transform.
    """
    side = 256
    return Compose([
        Resize(side, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(side),
        ToTensor(),
        Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
    ])
def image_transform_v1():
    """Build the v1 image preprocessing pipeline.

    256x256 bilinear resize + center crop, converted to a tensor.
    Unlike v2, no mean/std normalization is applied.

    Returns:
        A torchvision ``Compose`` transform.
    """
    side = 256
    return Compose([
        Resize(side, interpolation=InterpolationMode.BILINEAR),
        CenterCrop(side),
        ToTensor(),
    ])
def softmax(x, axis=-1):
    """Apply softmax to a numpy array along the given axis.

    Args:
        x: input numpy array.
        axis: axis along which to normalize; defaults to the last axis.

    Returns:
        Array of the same shape as ``x`` whose entries along ``axis``
        are non-negative and sum to 1.
    """
    # Subtract the per-axis max before exponentiating so large inputs
    # cannot overflow (numerical stabilization).
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exps / exps.sum(axis=axis, keepdims=True)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-ie", "--image_encoder_path", type=str, default="./mobileclip2_s4_image_encoder.axmodel",
                        help="image encoder axmodel path")
    parser.add_argument("-te", "--text_encoder_path", type=str, default="./mobileclip2_s4_text_encoder.axmodel",
                        help="text encoder axmodel path")
    parser.add_argument("-i", "--image", type=str, default="./zebra.jpg",
                        help="input image path")
    parser.add_argument("-t", "--class_text", type=str, nargs='+', default=["a zebra", "a dog", "two zebras"],
                        help='List of captions, e.g.: "a zebra" "a dog" "two zebras"')
    args = parser.parse_args()
    # NOTICE: use the v1 preprocessing; the v2 pipeline shows a fairly
    # large quantization error in pulsar2.
    preprocess = image_transform_v1()
    tokenizer = SimpleTokenizer(context_length=77)
    # Prepare model inputs: a single-image NCHW float batch and int32 token ids.
    image = preprocess(Image.open(args.image).convert('RGB')).unsqueeze(0)
    text = tokenizer(args.class_text).to(torch.int32)
    image_encoder = ort.InferenceSession(args.image_encoder_path)
    text_encoder = ort.InferenceSession(args.text_encoder_path)
    # The text encoder accepts the whole caption batch in one run.
    image_features = image_encoder.run(["unnorm_image_features"], {"image": np.array(image)})[0]
    text_features = text_encoder.run(["unnorm_text_features"], {"text": text.numpy()})[0]
    # L2-normalize both embeddings, then turn the scaled cosine-similarity
    # logits (100.0 is the CLIP-style logit scale) into probabilities.
    image_features /= np.linalg.norm(image_features, ord=2, axis=-1, keepdims=True)
    text_features /= np.linalg.norm(text_features, ord=2, axis=-1, keepdims=True)
    text_probs = softmax(100.0 * image_features @ text_features.T)
    print("Label probs:", text_probs)