import json
import os
import sys

# Set the allocator option before importing jax or any JAX-dependent module.
# The variable only has an effect if it is visible before JAX initialises its
# GPU backend, and doing it first guarantees that regardless of what the
# imports below do at import time.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax  # noqa: E402
import numpy as np  # noqa: E402
from PIL import Image  # noqa: E402
from transformers import CLIPImageProcessor  # noqa: E402

from clip_for_image_classification import FlaxCLIPForImageClassification  # noqa: E402

# Image to classify: the first command-line argument if given, otherwise the
# placeholder path (replace it with a real file).
IMAGE_PATH = sys.argv[1] if len(sys.argv) > 1 else "/your/image/here.jpg"

model = FlaxCLIPForImageClassification.from_pretrained(
    "Thouph/clip-vit-l-224-patch14-datacomp-image-classification"
)
image_processor = CLIPImageProcessor.from_pretrained(
    "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
)

# Open the image in a context manager so the file handle is released once the
# processor has converted it into model inputs.
with Image.open(IMAGE_PATH) as image:
    inputs = image_processor(images=image, return_tensors="jax")

outputs = model(**inputs)
logits = outputs.logits
# Element-wise sigmoid: an independent score in (0, 1) per output label
# (presumably a multi-label tagging head -- confirm against the model card).
probabilities = jax.nn.sigmoid(logits)
# np.asarray on a JAX array may be read-only. Copy it so that any in-place
# NumPy operation applied downstream works.
probabilities = np.asarray(probabilities).copy()
def topk_by_sort(input, k, axis=None, ascending=False):
    """Return the indices and values of the ``k`` extreme entries of ``input``.

    Selection uses a full ``np.argsort``, which is simple and adequate for
    label vectors of a few thousand entries.

    Args:
        input: Array-like of signed-integer or floating-point numbers. It is
            NOT modified. The name shadows the builtin but is kept for
            backward compatibility with keyword callers.
        k: Number of entries to return. Values larger than the length of the
            sorted axis are clamped to that length.
        axis: Axis to sort along. ``None`` (default) sorts the flattened array,
            so the returned indices are flat indices.
        ascending: If False (default), return the ``k`` largest entries,
            largest first. If True, return the ``k`` smallest, smallest first.

    Returns:
        Tuple ``(ind, val)`` containing the selected indices and the
        corresponding values of ``input``, both ordered along ``axis``.
    """
    arr = np.asarray(input)
    # Sort a negated copy rather than negating ``input`` in place. In-place
    # negation fails on read-only arrays, leaves the caller's data flipped if
    # anything raises in between, and on a Python list ``*= -1`` empties it.
    keys = arr if ascending else -arr
    ind = np.argsort(keys, axis=axis)
    # Length of the sorted axis (argsort flattens to 1-D when axis is None).
    n = ind.shape[0] if axis is None else ind.shape[axis]
    ind = np.take(ind, np.arange(min(k, n)), axis=axis)
    val = np.take_along_axis(arr, ind, axis=axis)
    return ind, val
# Number of highest-scoring tags to report.
TOP_K = 100

# axis=None ranks the flattened probabilities. The flat index equals the label
# index only while a single image is processed (assumed batch size 1 -- TODO
# confirm if this is ever run on batches).
indices, values = topk_by_sort(probabilities, TOP_K)

# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform default encoding (e.g. cp1252 on Windows).
# NOTE(review): presumably a JSON list whose i-th entry names model output i;
# verify against the tag file.
with open("7748tags.json", "r", encoding="utf-8") as file:
    allowed_tags = json.load(file)

for index, value in zip(indices, values, strict=True):
    print(allowed_tags[index], value)