File size: 2,798 Bytes
79562ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
from pathlib import Path

import gradio as gr
from transformers import pipeline


# ----------------------------
# Paths
# ----------------------------

# Directory containing this script; the other paths below are built from it.
BASE_DIR = Path(__file__).resolve().parent

# Adjust the model folder here if necessary.
# NOTE(review): the folder is named "flower-vit" although this app classifies
# big cats — confirm it holds the intended checkpoint.
MODEL_PATH = BASE_DIR.parent / "flower-vit"

# Folder holding the sample images that are offered as examples in the UI.
EXAMPLE_DIR = BASE_DIR / "example_images"


# ----------------------------
# Labels
# ----------------------------

# Class names the app distinguishes; classify_cat turns each one into a CLIP
# text prompt of the form "a photo of a <label>".
CAT_LABELS = ["cheetah", "leopard", "lion", "puma", "tiger"]


# ----------------------------
# Load models
# ----------------------------

# Both pipelines are created once at module import, so classify_cat reuses
# the same objects on every call.
print("Loading custom model...")
# Image classifier loaded from the path stored in MODEL_PATH.
vit_classifier = pipeline(
    "image-classification",
    model=str(MODEL_PATH)
)

print("Loading CLIP model...")
# Zero-shot classifier; the candidate labels are supplied on each call.
clip_classifier = pipeline(
    task="zero-shot-image-classification",
    model="openai/clip-vit-base-patch32"
)


# ----------------------------
# Helper functions
# ----------------------------

def normalize_custom_labels(results):
    """Convert raw classifier predictions into a ``{label: score}`` dict.

    Generic ``LABEL_<n>`` names (n = 0..4) are replaced by the matching
    big-cat class name; every other label is lower-cased. Scores are
    converted to ``float``. If two predictions end up with the same label,
    the later one overwrites the earlier one.

    Args:
        results: Iterable of mappings that each provide a ``"label"`` and a
            ``"score"`` entry.

    Returns:
        dict mapping each normalized label to its score, in input order.
    """
    generic_to_name = dict(
        LABEL_0="cheetah",
        LABEL_1="leopard",
        LABEL_2="lion",
        LABEL_3="puma",
        LABEL_4="tiger",
    )

    scores = {}
    for prediction in results:
        raw_name, confidence = prediction["label"], float(prediction["score"])
        name = (
            generic_to_name[raw_name]
            if raw_name in generic_to_name
            else raw_name.lower()
        )
        scores[name] = confidence
    return scores


# ----------------------------
# Main function
# ----------------------------

def classify_cat(image):
    """Classify one image with the custom model and with zero-shot CLIP.

    Args:
        image: Image input forwarded unchanged to both pipelines (the
            interface is configured to pass a file path). ``None`` means
            no image was provided.

    Returns:
        Tuple ``(vit_output, clip_output)`` of ``{label: score}`` dicts, one
        per model, or ``(None, None)`` when ``image`` is ``None``.
    """
    # Submitting the form without an image hands over None; skip inference
    # instead of letting the pipelines fail on it.
    if image is None:
        return None, None

    # Custom model
    vit_output = normalize_custom_labels(vit_classifier(image))

    # CLIP scores the image against one text prompt per class.
    prompt_prefix = "a photo of a "
    clip_labels = [f"{prompt_prefix}{label}" for label in CAT_LABELS]
    clip_results = clip_classifier(image, candidate_labels=clip_labels)

    # Strip the prompt wording again so both outputs use the same label names.
    clip_output = {
        r["label"].replace(prompt_prefix, "").lower(): float(r["score"])
        for r in clip_results
    }

    return vit_output, clip_output


# ----------------------------
# Example images
# ----------------------------

# Each entry is a one-element list holding the path of one sample image.
example_images = [
    [str(EXAMPLE_DIR / file_name)]
    for file_name in (
        "Cheetah_032.jpg",
        "Leopard_001.jpg",
        "Lion_003.jpg",
        "Puma_001.jpg",
        "Tiger_001.jpg",
    )
]


# ----------------------------
# Interface
# ----------------------------

# Web UI: one image input and two label outputs that receive the tuple
# returned by classify_cat (custom model first, CLIP second).
iface = gr.Interface(
    fn=classify_cat,
    inputs=gr.Image(type="filepath"),  # image is handed to fn as a file path
    outputs=[
        gr.Label(label="Custom Model"),
        gr.Label(label="CLIP")
    ],
    title="Big Cat Classification",
    description="Compare Custom Model vs CLIP",
    examples=example_images
)

# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()