# jina-clip-v2 / run_axmodel.py
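"""Run the jina-clip-v2 text and image encoders compiled into AXEngine
.axmodel files, then score text-image similarity as the dot product of
the two L2-normalized embeddings."""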
import argparse

import axengine as ort
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, AutoTokenizer


def _truncate_embeddings(embeddings: np.ndarray, truncate_dim: int) -> np.ndarray:
    # Matryoshka truncation: keep only the first `truncate_dim` dimensions.
    return embeddings[:, :truncate_dim]


def run_image_encoder(image_encoder, pixel_values, truncate_dim=None):
    """Encode preprocessed pixel values and return L2-normalized embeddings."""
    inputs = {image_encoder.get_inputs()[0].name: pixel_values}
    embeddings = image_encoder.run(None, inputs)[0]
    if truncate_dim:
        embeddings = _truncate_embeddings(embeddings, truncate_dim)
    # Normalize after truncation so dot products remain cosine similarities.
    embeddings = F.normalize(torch.from_numpy(embeddings), p=2, dim=1).numpy()
    return embeddings


def run_text_encoder(text_encoder, input_ids, truncate_dim=None):
    """Encode padded token ids and return L2-normalized embeddings."""
    inputs = {text_encoder.get_inputs()[0].name: input_ids}
    embeddings = text_encoder.run(None, inputs)[0]
    if truncate_dim:
        embeddings = _truncate_embeddings(embeddings, truncate_dim)
    embeddings = F.normalize(torch.from_numpy(embeddings), p=2, dim=1).numpy()
    return embeddings


if __name__ == '__main__':
    argparser = argparse.ArgumentParser()
    argparser.add_argument('-i', '--image', type=str, default='beach1.jpg', help='Path to the image file')
    argparser.add_argument('-t', '--text', type=str, default='beautiful sunset over the beach', help='Text to encode')
    argparser.add_argument('-iax', '--image_axmodel', type=str, default='image_encoder.axmodel', help='Path to the image axmodel file')
    argparser.add_argument('-tax', '--text_axmodel', type=str, default='text_encoder.axmodel', help='Path to the text axmodel file')
    argparser.add_argument('--hf_path', type=str, default='/root/wangjian/hf_cache/jina-clip-v2', help='Path to the Hugging Face model cache')
    args = argparser.parse_args()

    image_encoder_path = args.image_axmodel
    text_encoder_path = args.text_axmodel
    hf_path = args.hf_path

    image_encoder = ort.InferenceSession(image_encoder_path)
    text_encoder = ort.InferenceSession(text_encoder_path)
    tokenizer = AutoTokenizer.from_pretrained(hf_path, trust_remote_code=True)
    preprocess = AutoImageProcessor.from_pretrained(hf_path, trust_remote_code=True)

    # ============================== Text ======================================= #
    # Choose a Matryoshka dimension; set to None to keep the full 1024-dim vectors.
    truncate_dim = 512
    # The text axmodel is compiled for a fixed sequence length.
    max_seq_length = 50
    task = None  # e.g. 'retrieval.query'

    sentences = args.text
    if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
        sentences = [sentences]

    # Truncate at max_seq_length so the fixed-size padding below never has to
    # crop tokens (which would silently drop the trailing EOS token).
    tokenizer_kwargs = {'padding': True, 'max_length': max_seq_length, 'truncation': True}

    # from: /data/wangjian/project/hf_cache/jinaai/jina-clip-v2/config.json
    _task_instructions = {
        "retrieval.query": "Represent the query for retrieving evidence documents: "
    }
    if task is not None:
        instruction = _task_instructions[task]
        if instruction:
            sentences = [instruction + sentence for sentence in sentences]

    tokens = tokenizer(
        sentences,
        return_tensors='pt',
        **tokenizer_kwargs,
    )
    # Right-pad every sequence to the fixed length the axmodel expects.
    pad_token = tokenizer.pad_token_id
    input_ids = F.pad(tokens.input_ids, (0, max_seq_length - tokens.input_ids.shape[1]), value=pad_token)
    text_embeddings = run_text_encoder(text_encoder, input_ids.numpy().astype(np.int32), truncate_dim)

    # ============================== Image ======================================= #
    image_paths = [args.image]
    images = [Image.open(path).convert('RGB') for path in image_paths]
    pixel_values = preprocess(images).pixel_values
    image_embeddings = run_image_encoder(image_encoder, pixel_values.numpy(), truncate_dim)

    # Both embeddings are L2-normalized, so the dot product is cosine similarity.
    print(f'text -> image: {float(text_embeddings[0] @ image_embeddings[0])}')