import argparse

import numpy as np
import torch
import torch.nn.functional as f
from PIL import Image
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer, BatchEncoding

import axengine as ort
|
|
|
|
| def _truncate_embeddings(embeddings: torch.Tensor, truncate_dim: int): |
| return embeddings[:, :truncate_dim] |
|
|
|
|
def run_image_encoder(image_encoder, pixel_values, truncate_dim=512):
    """Embed pre-processed image pixels with the on-device image encoder.

    Parameters
    ----------
    image_encoder : session exposing ``get_inputs()`` / ``run()`` (ONNX-Runtime-like API).
    pixel_values : np.ndarray
        Batch of pre-processed pixels, fed to the session's first input.
    truncate_dim : int | None
        Keep only the first ``truncate_dim`` embedding dimensions before
        normalising (Matryoshka-style truncation); ``None``/0 disables it.
        Defaults to 512, the value the script previously supplied through a
        module-level global — reading that global raised ``NameError`` when
        this function was imported outside the ``__main__`` guard.

    Returns
    -------
    np.ndarray
        L2-normalised embeddings, shape (batch, dim).
    """
    feeds = {image_encoder.get_inputs()[0].name: pixel_values}
    embeddings = image_encoder.run(None, feeds)[0]
    if truncate_dim:
        embeddings = embeddings[:, :truncate_dim]
    # Normalise rows to unit L2 norm so dot products are cosine similarities.
    embeddings = f.normalize(torch.from_numpy(embeddings), p=2, dim=1).numpy()
    return embeddings
|
|
|
|
def run_text_encoder(text_encoder, input_ids, truncate_dim=512):
    """Embed tokenised text with the on-device text encoder.

    Parameters
    ----------
    text_encoder : session exposing ``get_inputs()`` / ``run()`` (ONNX-Runtime-like API).
    input_ids : np.ndarray
        Token-id batch, fed to the session's first input.
    truncate_dim : int | None
        Keep only the first ``truncate_dim`` embedding dimensions before
        normalising (Matryoshka-style truncation); ``None``/0 disables it.
        Defaults to 512, the value the script previously supplied through a
        module-level global — reading that global raised ``NameError`` when
        this function was imported outside the ``__main__`` guard.

    Returns
    -------
    np.ndarray
        L2-normalised embeddings, shape (batch, dim).
    """
    feeds = {text_encoder.get_inputs()[0].name: input_ids}
    embeddings = text_encoder.run(None, feeds)[0]
    if truncate_dim:
        embeddings = embeddings[:, :truncate_dim]
    # Normalise rows to unit L2 norm so dot products are cosine similarities.
    embeddings = f.normalize(torch.from_numpy(embeddings), p=2, dim=1).numpy()
    return embeddings
|
|
|
|
if __name__ == '__main__':
    # CLI: paths to the two compiled axmodel encoders plus one image/text
    # pair to embed and compare.
    argparser = argparse.ArgumentParser()
    argparser.add_argument('-i', '--image', type=str, default='beach1.jpg', help='Path to the image file')
    argparser.add_argument('-t', '--text', type=str, default='beautiful sunset over the beach', help='Text to encode')
    argparser.add_argument('-iax', '--image_axmodel', type=str, default='image_encoder.axmodel', help='Path to the image axmodel file')
    argparser.add_argument('-tax', '--text_axmodel', type=str, default='text_encoder.axmodel', help='Path to the text axmodel file')
    argparser.add_argument('--hf_path', type=str, default='/root/wangjian/hf_cache/jina-clip-v2', help='Path to the Hugging Face model cache')
    args = argparser.parse_args()

    image_encoder_path = args.image_axmodel
    text_encoder_path = args.text_axmodel
    hf_path = args.hf_path

    # axengine exposes an onnxruntime-like InferenceSession API.
    image_encoder = ort.InferenceSession(image_encoder_path)
    text_encoder = ort.InferenceSession(text_encoder_path)

    # Tokenizer and image preprocessor are loaded from the local HF model dir.
    tokenizer = AutoTokenizer.from_pretrained(hf_path, trust_remote_code=True)
    preprocess = AutoImageProcessor.from_pretrained(hf_path, trust_remote_code=True)

    # Module-level settings; `truncate_dim` is read as a global by
    # run_image_encoder / run_text_encoder.
    truncate_dim = 512    # keep only the first 512 embedding dims
    max_seq_length = 50   # fixed token length fed to the text axmodel
    task = None           # optional key into _task_instructions below
    sentences = args.text

    # Accept a single string (the CLI case) as a one-element batch.
    if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
        sentences = [sentences]

    # NOTE(review): tokenizer truncates at max_length=512, but the ids are
    # later forced to max_seq_length=50 via F.pad — a negative pad amount
    # silently crops sequences longer than 50 tokens. Confirm intended.
    tokenizer_kwargs = {'padding': True, 'max_length': 512, 'truncation': True}

    # Optional instruction prefix per task (prompt-style conditioning).
    _task_instructions = {
        "retrieval.query": "Represent the query for retrieving evidence documents: "
    }
    if task is not None:
        instruction = _task_instructions[task]
        if instruction:
            sentences = [instruction + sentence for sentence in sentences]

    tokens = tokenizer(
        sentences,
        return_tensors='pt',
        **tokenizer_kwargs,
    )

    # Right-pad (or crop, when the pad width is negative) every sequence to
    # the fixed length the axmodel was compiled for, using the <pad> id.
    pad_token = tokenizer.get_added_vocab()["<pad>"]
    input_ids = torch.nn.functional.pad(tokens.input_ids, (0, max_seq_length - tokens.input_ids.shape[1]), value=pad_token)

    text_embeddings = run_text_encoder(text_encoder, input_ids.numpy().astype(np.int32))

    # Load and RGB-convert each input image before preprocessing.
    image_urls = [args.image]
    _processed_images = []
    for img in image_urls:
        image = Image.open(img).convert('RGB')
        _processed_images.append(image)

    pixelvals = preprocess(_processed_images)

    image_embeddings = run_image_encoder(image_encoder, pixelvals.pixel_values.numpy())

    # Both embedding sets are L2-normalised, so the dot product is the
    # cosine similarity between the text and the image.
    print("text -> image: " + str(text_embeddings[0] @ image_embeddings[0].T))
|
|