Spaces:
Sleeping
Sleeping
Upload zero_shot_classification.py
Browse files- zero_shot_classification.py +90 -0
zero_shot_classification.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Zero-shot image classification on Oxford Pets using CLIP.

Runs ``openai/clip-vit-large-patch14`` over the first 100 images of the
``pcuenq/oxford-pets`` dataset, reports accuracy / precision / recall,
and writes the summary to ``zero_shot_results.md``.
"""

import io
import os
import shutil
import sys

from transformers import pipeline
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score

# The 37 Oxford-IIIT Pets class names. Capitalization matters: cat breeds
# are capitalized and dog breeds are lowercase, matching the dataset's
# 'label' strings so predictions compare equal to ground truth.
LABELS_OXFORD_PETS = [
    'Siamese', 'Birman', 'shiba inu', 'staffordshire bull terrier', 'basset hound', 'Bombay', 'japanese chin',
    'chihuahua', 'german shorthaired', 'pomeranian', 'beagle', 'english cocker spaniel', 'american pit bull terrier',
    'Ragdoll', 'Persian', 'Egyptian Mau', 'miniature pinscher', 'Sphynx', 'Maine Coon', 'keeshond', 'yorkshire terrier',
    'havanese', 'leonberger', 'wheaten terrier', 'american bulldog', 'english setter', 'boxer', 'newfoundland', 'Bengal',
    'samoyed', 'British Shorthair', 'great pyrenees', 'Abyssinian', 'pug', 'saint bernard', 'Russian Blue', 'scottish terrier'
]


def _maybe_clear_dataset_cache() -> None:
    """Delete the Hugging Face datasets cache, but only when explicitly asked.

    The original script unconditionally wiped ``~/.cache/huggingface/datasets``
    on every run, destroying *every* cached dataset on the machine and forcing
    full re-downloads. Set ``CLEAR_HF_DATASETS_CACHE=1`` to opt in.
    """
    if os.environ.get("CLEAR_HF_DATASETS_CACHE") != "1":
        return
    cache_dir = os.path.expanduser("~/.cache/huggingface/datasets")
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)


def main() -> None:
    """Run the zero-shot evaluation end to end and write the report file."""
    _maybe_clear_dataset_cache()

    # Load the CLIP model for zero-shot classification.
    print("Loading CLIP model...")
    checkpoint = "openai/clip-vit-large-patch14"
    detector = pipeline(model=checkpoint, task="zero-shot-image-classification")

    # Load the Oxford Pets dataset.
    print("Loading Oxford Pets dataset...")
    try:
        # Only use first 100 images for faster testing.
        dataset = load_dataset('pcuenq/oxford-pets', split='train[:100]')
        print(f"Loaded {len(dataset)} images")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        sys.exit(1)

    true_labels = []
    predicted_labels = []

    print("Processing images...")
    # tqdm already renders a live progress bar, so no extra "Processed N/M"
    # prints are needed (they would break the bar's rendering anyway).
    for i, sample in enumerate(tqdm(dataset, desc="Processing images")):
        try:
            # Decode the raw bytes into a PIL image. Convert to RGB because
            # some source files are palette/grayscale/RGBA, which CLIP's
            # image preprocessing does not expect.
            image = Image.open(io.BytesIO(sample['image']['bytes'])).convert('RGB')

            # Run the detector on the image with the provided labels.
            results = detector(image, candidate_labels=LABELS_OXFORD_PETS)

            # The pipeline output is already sorted by score, but max() is
            # cheap insurance against that contract changing and avoids a
            # full O(n log n) sort just to read the top element.
            best = max(results, key=lambda r: r['score'])

            true_labels.append(sample['label'])
            predicted_labels.append(best['label'])
        except Exception as e:
            # Best-effort evaluation: log the failure and keep going so one
            # corrupt image does not abort the whole run.
            print(f"Error processing image {i}: {e}")
            continue

    if not true_labels:
        # sklearn's metrics raise on empty inputs; fail with a clear message.
        print("No images were processed successfully; nothing to score.")
        sys.exit(1)

    # Calculate metrics. zero_division=0 keeps sklearn from emitting
    # undefined-metric warnings for classes that never appear among the
    # predictions on this small 100-image sample.
    accuracy = accuracy_score(true_labels, predicted_labels)
    precision = precision_score(
        true_labels, predicted_labels,
        average='weighted', labels=LABELS_OXFORD_PETS, zero_division=0,
    )
    recall = recall_score(
        true_labels, predicted_labels,
        average='weighted', labels=LABELS_OXFORD_PETS, zero_division=0,
    )

    # Print and save results.
    results = f"""
Zero-Shot Classification Results using CLIP (openai/clip-vit-large-patch14)
====================================================================
Accuracy: {accuracy:.4f}
Precision: {precision:.4f}
Recall: {recall:.4f}
"""

    print(results)

    # Save results to a file.
    with open('zero_shot_results.md', 'w') as f:
        f.write(results)


if __name__ == "__main__":
    main()