Update app.py
Browse files
app.py
CHANGED
|
@@ -1,68 +1,14 @@
|
|
| 1 |
-
from transformers import pipeline
from datasets import load_dataset

from PIL import Image
import io
from tqdm import tqdm

# Zero-shot CLIP baseline: no training, the candidate labels below serve
# as the "classes" at inference time.
checkpoint = "openai/clip-vit-large-patch14"
detector = pipeline(model=checkpoint, task="zero-shot-image-classification")

dataset = load_dataset('pcuenq/oxford-pets')

# The 37 Oxford-IIIT Pets breed names (cat breeds capitalized, dog breeds
# lower-case, matching the dataset's label strings).
labels_oxford_pets = ['Siamese', 'Birman', 'shiba inu', 'staffordshire bull terrier', 'basset hound', 'Bombay', 'japanese chin', 'chihuahua', 'german shorthaired', 'pomeranian', 'beagle', 'english cocker spaniel', 'american pit bull terrier', 'Ragdoll', 'Persian', 'Egyptian Mau', 'miniature pinscher', 'Sphynx', 'Maine Coon', 'keeshond', 'yorkshire terrier', 'havanese', 'leonberger', 'wheaten terrier', 'american bulldog', 'english setter', 'boxer', 'newfoundland', 'Bengal', 'samoyed', 'British Shorthair', 'great pyrenees', 'Abyssinian', 'pug', 'saint bernard', 'Russian Blue', 'scottish terrier']

# Lists to store true labels and predicted labels
true_labels = []
predicted_labels = []

for i in tqdm(range(len(dataset['train']))):
    # Each record stores the image as raw bytes; decode via PIL.
    image_bytes = dataset['train'][i]['image']['bytes']
    image = Image.open(io.BytesIO(image_bytes))

    # The zero-shot pipeline already returns candidates sorted by
    # descending score, so the top prediction is simply the first entry
    # (the original re-sorted redundantly).
    results = detector(image, candidate_labels=labels_oxford_pets)
    predicted_label = results[0]['label']

    true_labels.append(dataset['train'][i]['label'])
    predicted_labels.append(predicted_label)

from sklearn.metrics import accuracy_score, precision_score, recall_score

# Calculate accuracy
accuracy = accuracy_score(true_labels, predicted_labels)

# Weighted precision/recall over the fixed label set, so classes absent
# from the predictions still contribute.
precision = precision_score(true_labels, predicted_labels, average='weighted', labels=labels_oxford_pets)
recall = recall_score(true_labels, predicted_labels, average='weighted', labels=labels_oxford_pets)

# Print the results
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
|
| 59 |
-
|
| 60 |
import gradio as gr
|
| 61 |
from transformers import pipeline
|
| 62 |
|
| 63 |
-
#
|
| 64 |
-
vit_classifier = pipeline("image-classification", model="
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
labels_oxford_pets = [
|
| 68 |
'Siamese', 'Birman', 'shiba inu', 'staffordshire bull terrier', 'basset hound', 'Bombay', 'japanese chin',
|
|
@@ -76,10 +22,13 @@ def classify_pet(image):
|
|
| 76 |
vit_results = vit_classifier(image)
|
| 77 |
vit_output = {result['label']: result['score'] for result in vit_results}
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
|
| 82 |
-
return {
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
example_images = [
|
| 85 |
["example_images/dog1.jpeg"],
|
|
@@ -94,175 +43,8 @@ iface = gr.Interface(
|
|
| 94 |
inputs=gr.Image(type="filepath"),
|
| 95 |
outputs=gr.JSON(),
|
| 96 |
title="Pet Classification Comparison",
|
| 97 |
-
description="Upload an image of a pet, and compare results from a trained ViT model and a zero-shot
|
| 98 |
examples=example_images
|
| 99 |
)
|
| 100 |
|
| 101 |
-
iface.launch()
|
| 102 |
-
|
| 103 |
-
import io
|
| 104 |
-
from PIL import Image
|
| 105 |
-
|
| 106 |
-
import numpy as np
|
| 107 |
-
import pandas as pd
|
| 108 |
-
import matplotlib.pyplot as plt
|
| 109 |
-
|
| 110 |
-
import torch
|
| 111 |
-
import torch.nn as nn
|
| 112 |
-
|
| 113 |
-
from huggingface_hub import notebook_login
|
| 114 |
-
|
| 115 |
-
from datasets import load_dataset, DatasetDict
|
| 116 |
-
|
| 117 |
-
from transformers import AutoImageProcessor, ViTForImageClassification
|
| 118 |
-
|
| 119 |
-
from transformers import Trainer, TrainingArguments
|
| 120 |
-
|
| 121 |
-
import evaluate
|
| 122 |
-
|
| 123 |
-
# Login onto Hugging Face hub to load any private dataset/model.
|
| 124 |
-
# We need to login as we'll also upload our model to the hub
|
| 125 |
-
notebook_login()
|
| 126 |
-
|
| 127 |
-
dataset = load_dataset('pcuenq/oxford-pets')
|
| 128 |
-
dataset
|
| 129 |
-
|
| 130 |
-
dataset['train'][0]
|
| 131 |
-
|
| 132 |
-
labels = dataset['train'].unique('label')
|
| 133 |
-
print(len(labels),labels)
|
| 134 |
-
|
| 135 |
-
def show_samples(ds,rows,cols):
    """Display a rows x cols grid of random images from *ds* with their labels.

    ds: a HF dataset split whose records have 'image' (dict with raw
    'bytes') and 'label' (breed-name string) fields.
    """
    samples = ds.shuffle().select(np.arange(rows*cols)) # selecting random images
    fig = plt.figure(figsize=(cols*4,rows*4))
    # plotting
    for i in range(rows*cols):
        # each record stores the image as raw bytes; decode via PIL
        img_bytes = samples[i]['image']['bytes']
        img = Image.open(io.BytesIO(img_bytes))
        label = samples[i]['label']
        fig.add_subplot(rows,cols,i+1)
        plt.imshow(img)
        plt.title(label)
        plt.axis('off')

# Sanity-check the data by eyeballing a 3x5 grid of training samples.
show_samples(dataset['train'],rows=3,cols=5)
|
| 149 |
-
|
| 150 |
-
# The source dataset only ships a 'train' split, so carve out held-out
# data ourselves: 80% train, then split the remaining 20% in half.
split_dataset = dataset['train'].train_test_split(test_size=0.2) # 80% train, 20% evaluation
eval_dataset = split_dataset['test'].train_test_split(test_size=0.5) # 50% validation, 50% test

# recombining the splits using a DatasetDict

our_dataset = DatasetDict({
    'train': split_dataset['train'],
    # note: eval_dataset's 'train' half becomes our validation set
    'validation': eval_dataset['train'],
    'test': eval_dataset['test']
})

our_dataset
|
| 162 |
-
|
| 163 |
-
# Bidirectional label <-> integer-id maps required by the HF model config.
# NOTE(review): ids follow the order of dataset.unique('label'), which is
# data-dependent — stable only for the same dataset version.
label2id = {c:idx for idx,c in enumerate(labels)}
id2label = {idx:c for idx,c in enumerate(labels)}

# Image processor matching the ViT checkpoint (resize/normalize settings).
processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224')
processor
|
| 168 |
-
|
| 169 |
-
def transforms(batch):
    """Decode a batch's raw image bytes and run the ViT processor.

    Returns the processor's output (pixel_values tensors) extended with
    integer 'labels' mapped through the module-level label2id dict.
    """
    decoded = [
        Image.open(io.BytesIO(entry['bytes'])).convert('RGB')
        for entry in batch['image']
    ]
    batch['image'] = decoded
    inputs = processor(decoded, return_tensors='pt')
    inputs['labels'] = [label2id[name] for name in batch['label']]
    return inputs


# Lazily applied on access — images are decoded per batch, not up front.
processed_dataset = our_dataset.with_transform(transforms)

processed_dataset
|
| 180 |
-
|
| 181 |
-
def collate_fn(batch):
    """Collate a list of per-sample dicts into one training batch.

    Stacks each sample's 'pixel_values' tensor along a new batch axis and
    gathers the integer 'labels' into a single tensor.
    """
    pixel_values = [sample['pixel_values'] for sample in batch]
    label_ids = [sample['labels'] for sample in batch]
    return {
        'pixel_values': torch.stack(pixel_values),
        'labels': torch.tensor(label_ids),
    }
|
| 186 |
-
|
| 187 |
-
# HF `evaluate` metric handle, reused on every evaluation pass.
accuracy = evaluate.load('accuracy')

def compute_metrics(eval_preds):
    """Accuracy from the (logits, label_ids) pair the Trainer passes in."""
    logits, labels = eval_preds
    # predicted class id = index of the largest logit in each row
    predicted_ids = np.argmax(logits, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=labels)
|
| 193 |
-
|
| 194 |
-
# Load the pretrained ViT and swap in a fresh 37-class head.
# ignore_mismatched_sizes lets the new classifier replace the original
# 1000-class ImageNet head without a shape error.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels = len(labels),
    id2label = id2label,
    label2id = label2id,
    ignore_mismatched_sizes = True
)

model

# Transfer learning: freeze the backbone so only the classifier head
# (parameters named 'classifier.*') receives gradients.
for name,p in model.named_parameters():
    if not name.startswith('classifier'):
        p.requires_grad = False

# Report how much of the model is actually trainable after freezing.
num_params = sum([p.numel() for p in model.parameters()])
trainable_params = sum([p.numel() for p in model.parameters() if p.requires_grad])

print(f"{num_params = :,} | {trainable_params = :,}")
|
| 212 |
-
|
| 213 |
-
# Hyperparameters and bookkeeping for the Trainer run.
training_args = TrainingArguments(
    output_dir="./vit-base-oxford-iiit-pets",
    per_device_train_batch_size=16,
    # evaluate and checkpoint once per epoch so load_best_model_at_end
    # can pick the best epoch checkpoint
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=100,
    num_train_epochs=5,
    # relatively high LR is fine here: only the classifier head trains
    learning_rate=3e-4,
    save_total_limit=2,
    # keep the raw dataset columns — our transform/collate pipeline needs them
    remove_unused_columns=False,
    push_to_hub=True,
    report_to='tensorboard',
    load_best_model_at_end=True,
    run_name="transferlearning-google_vit-base-patch16-224 no data augmentation"
)
|
| 228 |
-
|
| 229 |
-
# Wire model, data, collation and metrics into the HF Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=processed_dataset["train"],
    eval_dataset=processed_dataset["validation"],
    # passing the image processor as "tokenizer" makes the Trainer save
    # its config alongside the model checkpoint
    tokenizer=processor
)

trainer.train()

# Final score on the held-out test split (never seen during training).
trainer.evaluate(processed_dataset['test'])
|
| 242 |
-
|
| 243 |
-
def show_predictions(rows,cols):
    """Plot a rows x cols grid of random test images with true vs predicted labels.

    Uses the module-level `trainer`, `our_dataset`, `transforms` and
    `id2label` — intended for interactive inspection after training.
    """
    samples = our_dataset['test'].shuffle().select(np.arange(rows*cols))
    processed_samples = samples.with_transform(transforms)
    predictions = trainer.predict(processed_samples).predictions.argmax(axis=1) # predicted labels from logits
    fig = plt.figure(figsize=(cols*4,rows*4))
    for i in range(rows*cols):
        # decode the raw bytes of the *unprocessed* sample for display
        img_bytes = samples[i]['image']['bytes']
        img = Image.open(io.BytesIO(img_bytes))
        prediction = predictions[i]
        label = f"label: {samples[i]['label']}\npredicted: {id2label[prediction]}"
        fig.add_subplot(rows,cols,i+1)
        plt.imshow(img)
        plt.title(label)
        plt.axis('off')

show_predictions(rows=5,cols=5)
|
| 259 |
-
|
| 260 |
-
# Metadata for the auto-generated model card on the Hub.
kwargs = {
    "finetuned_from": model.config._name_or_path,
    "dataset": 'pcuenq/oxford-pets',
    "tasks": "image-classification",
    "tags": ['image-classification'],
}

# Save locally, then upload model + card to the Hugging Face Hub
# (the string argument is the commit message).
trainer.save_model()
trainer.push_to_hub('🐕️🐈️', **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from transformers import pipeline
|
| 3 |
|
| 4 |
+
# ViT model for standard (supervised) image classification,
# fine-tuned on Oxford-IIIT Pets.
vit_classifier = pipeline("image-classification", model="dewiri/vit-base-oxford-iiit-pets")

# SIGLIP model for zero-shot classification (used instead of CLIP)
siglip_detector = pipeline(
    model="google/siglip-so400m-patch14-384",
    task="zero-shot-image-classification"
)
|
| 12 |
|
| 13 |
labels_oxford_pets = [
|
| 14 |
'Siamese', 'Birman', 'shiba inu', 'staffordshire bull terrier', 'basset hound', 'Bombay', 'japanese chin',
|
|
|
|
| 22 |
vit_results = vit_classifier(image)
|
| 23 |
vit_output = {result['label']: result['score'] for result in vit_results}
|
| 24 |
|
| 25 |
+
siglip_results = siglip_detector(image, candidate_labels=labels_oxford_pets)
|
| 26 |
+
siglip_output = {result['label']: result['score'] for result in siglip_results}
|
| 27 |
|
| 28 |
+
return {
|
| 29 |
+
"ViT Classification": vit_output,
|
| 30 |
+
"SIGLIP Zero-Shot Classification": siglip_output
|
| 31 |
+
}
|
| 32 |
|
| 33 |
example_images = [
|
| 34 |
["example_images/dog1.jpeg"],
|
|
|
|
| 43 |
inputs=gr.Image(type="filepath"),
|
| 44 |
outputs=gr.JSON(),
|
| 45 |
title="Pet Classification Comparison",
|
| 46 |
+
description="Upload an image of a pet, and compare results from a trained ViT model and a zero-shot SIGLIP model.",
|
| 47 |
examples=example_images
|
| 48 |
)
|
| 49 |
|
| 50 |
+
iface.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|