import functools

import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel


@functools.lru_cache(maxsize=2)
def _load_clip(model_name):
    """Load and cache a CLIP model/processor pair.

    Loading the pretrained weights is expensive (disk/network + init), so the
    result is memoized per model name rather than reloaded on every call.

    Args:
        model_name: Hugging Face model identifier.

    Returns:
        Tuple of (CLIPModel, CLIPProcessor), with the model in eval mode.
    """
    model = CLIPModel.from_pretrained(model_name)
    model.eval()
    processor = CLIPProcessor.from_pretrained(model_name)
    return model, processor


def classify_image(image_path, text_labels, model_name="openai/clip-vit-base-patch32"):
    """Classify an image against a list of text labels using CLIP.

    Computes zero-shot classification probabilities by comparing the image
    embedding to each text label's embedding.

    Args:
        image_path: Path to the image file to classify.
        text_labels: List of candidate label strings.
        model_name: Hugging Face CLIP model identifier (defaults to the
            original hard-coded checkpoint, so existing callers are unchanged).

    Returns:
        Dict mapping each label in ``text_labels`` to its softmax probability.
    """
    model, processor = _load_clip(model_name)

    # Use a context manager so the file handle is closed promptly, and
    # normalize to RGB so palette/RGBA/grayscale images work with the
    # processor's pixel pipeline.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    inputs = processor(text=text_labels, images=image, return_tensors="pt", padding=True)

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image: (1, num_labels); softmax over labels yields probabilities.
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)
    return dict(zip(text_labels, probs.tolist()[0]))


if __name__ == '__main__':
    # Create a dummy image for testing
    dummy_image = Image.new('RGB', (100, 100), color = 'red')
    dummy_image.save("dummy_image.png")
    labels = ["a red square", "a blue circle", "a green triangle"]
    probabilities = classify_image("dummy_image.png", labels)
    print("Probabilities:", probabilities)
    print("Predicted label:", max(probabilities, key=probabilities.get))