| import os
|
| from pathlib import Path
|
|
|
| import gradio as gr
|
| from transformers import pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
# Absolute directory of this script; every other path is anchored here so the
# app works regardless of the current working directory.
BASE_DIR = Path(__file__).resolve().parent

# Fine-tuned ViT checkpoint directory, located next to this script's folder.
# NOTE(review): the name "flower-vit" looks unrelated to big-cat
# classification -- confirm this directory really holds the big-cat checkpoint.
MODEL_PATH = BASE_DIR.parent.joinpath("flower-vit")

# Folder holding the sample images shown as clickable examples in the UI.
EXAMPLE_DIR = BASE_DIR.joinpath("example_images")

# The five classes, in the same order as the LABEL_0..LABEL_4 ids that
# normalize_custom_labels translates; also used to build the CLIP prompts.
CAT_LABELS = "cheetah leopard lion puma tiger".split()
|
|
|
|
|
|
|
|
|
|
|
|
|
# Both pipelines are constructed once, at import time, and shared by every
# call to classify_cat.
print("Loading custom model...")
vit_classifier = pipeline(task="image-classification", model=os.fspath(MODEL_PATH))

print("Loading CLIP model...")
clip_classifier = pipeline(
    task="zero-shot-image-classification",
    model="openai/clip-vit-base-patch32",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
def normalize_custom_labels(results):
    """Turn image-classification pipeline output into a ``{label: score}`` dict.

    Each item of *results* must provide ``"label"`` and ``"score"`` keys. Generic
    ids ``LABEL_0`` .. ``LABEL_4`` are translated to the corresponding cat name
    (presumably the checkpoint was saved without human-readable label names --
    confirm against the model config). Any other label is lower-cased as-is.
    Scores are coerced to ``float``. If two entries normalize to the same
    label, the later one wins.
    """
    raw_to_name = {
        "LABEL_0": "cheetah",
        "LABEL_1": "leopard",
        "LABEL_2": "lion",
        "LABEL_3": "puma",
        "LABEL_4": "tiger",
    }

    normalized = {}
    for entry in results:
        raw = entry["label"]
        prob = float(entry["score"])
        name = raw_to_name[raw] if raw in raw_to_name else raw.lower()
        normalized[name] = prob
    return normalized
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_cat(image):
    """Classify one image with both the custom ViT model and zero-shot CLIP.

    Args:
        image: Filepath of the uploaded image (the input component is
            ``gr.Image(type="filepath")``), or ``None`` when the user presses
            Submit without providing an image.

    Returns:
        A ``(vit_output, clip_output)`` pair of ``{label: score}`` dicts, one
        for each ``gr.Label`` output component.

    Raises:
        gr.Error: If no image was provided. Gradio shows this as a readable
            message instead of an opaque failure from inside the pipelines.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    vit_output = normalize_custom_labels(vit_classifier(image))

    # Keep an explicit prompt -> label mapping so CLIP results are mapped back
    # exactly, instead of string-stripping the prompt prefix afterwards.
    prompt_to_label = {f"a photo of a {label}": label for label in CAT_LABELS}
    clip_results = clip_classifier(image, candidate_labels=list(prompt_to_label))
    clip_output = {
        prompt_to_label.get(r["label"], r["label"]).lower(): float(r["score"])
        for r in clip_results
    }

    return vit_output, clip_output
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio expects one row per example; each row holds a single filepath
# because the interface has exactly one input component.
example_images = [
    [str(EXAMPLE_DIR / filename)]
    for filename in (
        "Cheetah_032.jpg",
        "Leopard_001.jpg",
        "Lion_003.jpg",
        "Puma_001.jpg",
        "Tiger_001.jpg",
    )
]
|
|
|
|
|
|
|
|
|
|
|
|
|
# Side-by-side UI: one image in, one label panel per model out.
iface = gr.Interface(
    fn=classify_cat,
    inputs=[gr.Image(type="filepath")],
    outputs=[
        gr.Label(label="Custom Model"),  # first value returned by classify_cat
        gr.Label(label="CLIP"),          # second value returned by classify_cat
    ],
    title="Big Cat Classification",
    description="Compare Custom Model vs CLIP",
    examples=example_images,
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()