|
|
|
|
|
from functools import lru_cache

import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
|
|
|
|
|
# Single checkpoint name used for both the model and its processor.
_CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"


@lru_cache(maxsize=1)
def _load_clip():
    """Load and cache the CLIP model and processor.

    The download/instantiation happens only on the first call; subsequent
    calls reuse the same objects, which makes repeated ``classify_image``
    calls dramatically cheaper.
    """
    model = CLIPModel.from_pretrained(_CLIP_MODEL_NAME)
    processor = CLIPProcessor.from_pretrained(_CLIP_MODEL_NAME)
    model.eval()  # inference-only: disable dropout/batch-norm training behavior
    return model, processor


def classify_image(image_path, text_labels):
    """Classify an image against a list of text labels using a CLIP model.

    Args:
        image_path: Path to the image file to classify.
        text_labels: Non-empty list of candidate label strings.

    Returns:
        Dict mapping each label in ``text_labels`` to its softmax
        probability; the values sum to 1.

    Raises:
        ValueError: If ``text_labels`` is empty.
        FileNotFoundError: If ``image_path`` does not exist.
    """
    if not text_labels:
        raise ValueError("text_labels must be a non-empty list of strings")

    model, processor = _load_clip()

    # Convert to RGB so palette/grayscale/RGBA images don't break the
    # processor; the context manager releases the file handle promptly.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    inputs = processor(text=text_labels, images=image, return_tensors="pt", padding=True)

    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image has shape (1, num_labels); softmax over the label
    # axis turns the similarity scores into probabilities.
    probs = outputs.logits_per_image.softmax(dim=1)

    return dict(zip(text_labels, probs[0].tolist()))
|
|
|
|
|
if __name__ == '__main__':
    # Build a small solid-red test image to exercise the classifier end-to-end.
    demo_path = "dummy_image.png"
    Image.new('RGB', (100, 100), color='red').save(demo_path)

    candidate_labels = ["a red square", "a blue circle", "a green triangle"]
    scores = classify_image(demo_path, candidate_labels)

    # Report the full distribution, then the single most likely label.
    print("Probabilities:", scores)
    print("Predicted label:", max(scores, key=scores.get))
|
|
|
|
|
|