"""Gradio demo comparing a custom image classifier with zero-shot CLIP on big cats."""

import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
from transformers import pipeline

# ----------------------------
# Paths
# ----------------------------
BASE_DIR = Path(__file__).resolve().parent

# Adjust the model folder here if necessary.
# NOTE(review): the folder is named "flower-vit" although the app classifies
# big cats -- confirm this is the intended checkpoint.
MODEL_PATH = BASE_DIR.parent / "flower-vit"
EXAMPLE_DIR = BASE_DIR / "example_images"

# ----------------------------
# Labels
# ----------------------------
CAT_LABELS = ["cheetah", "leopard", "lion", "puma", "tiger"]

# Generic "LABEL_<index>" names mapped to the class names, in CAT_LABELS order.
# Derived from CAT_LABELS so the two can never drift apart.
_GENERIC_LABELS = {f"LABEL_{index}": name for index, name in enumerate(CAT_LABELS)}

# Text prompt prefix used to turn a class name into a CLIP candidate label.
_CLIP_PROMPT_PREFIX = "a photo of a "

# ----------------------------
# Load models
# ----------------------------
# The models are loaded once at import time so every request reuses them.
print("Loading custom model...")
vit_classifier = pipeline(
    "image-classification",
    model=str(MODEL_PATH),
)

print("Loading CLIP model...")
clip_classifier = pipeline(
    task="zero-shot-image-classification",
    model="openai/clip-vit-base-patch32",
)


# ----------------------------
# Helper functions
# ----------------------------
def normalize_custom_labels(results: List[Dict[str, Any]]) -> Dict[str, float]:
    """Convert classifier results into a ``{label: score}`` mapping.

    Args:
        results: Iterable of dicts, each with a ``"label"`` and a ``"score"`` key.

    Returns:
        A dict mapping label to float score. Labels of the form ``LABEL_0`` ..
        ``LABEL_4`` are translated to the matching entry of ``CAT_LABELS``;
        any other label is lower-cased.
    """
    output = {}
    for result in results:
        raw_label = result["label"]
        # Translate generic ids; otherwise just normalise the casing.
        label = _GENERIC_LABELS.get(raw_label, raw_label.lower())
        output[label] = float(result["score"])
    return output


# ----------------------------
# Main function
# ----------------------------
def classify_cat(image: Optional[str]) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Classify an image with both the custom model and CLIP.

    Args:
        image: Path to the image file (the input component uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns:
        A ``(custom_scores, clip_scores)`` tuple of ``{label: score}`` dicts.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Submitting without an image would otherwise fail inside the pipeline
        # with an unhelpful traceback.
        raise gr.Error("Please upload an image first.")

    # Custom model
    vit_output = normalize_custom_labels(vit_classifier(image))

    # CLIP: score the image against one text prompt per class, then map each
    # prompt back to its plain class name.
    prompt_to_label = {
        f"{_CLIP_PROMPT_PREFIX}{label}": label for label in CAT_LABELS
    }
    clip_results = clip_classifier(image, candidate_labels=list(prompt_to_label))
    clip_output = {
        prompt_to_label.get(result["label"], result["label"]).lower(): float(
            result["score"]
        )
        for result in clip_results
    }

    return vit_output, clip_output


# ----------------------------
# Example images
# ----------------------------
example_images = [
    [str(EXAMPLE_DIR / "Cheetah_032.jpg")],
    [str(EXAMPLE_DIR / "Leopard_001.jpg")],
    [str(EXAMPLE_DIR / "Lion_003.jpg")],
    [str(EXAMPLE_DIR / "Puma_001.jpg")],
    [str(EXAMPLE_DIR / "Tiger_001.jpg")],
]

# ----------------------------
# Interface
# ----------------------------
iface = gr.Interface(
    fn=classify_cat,
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.Label(label="Custom Model"),
        gr.Label(label="CLIP"),
    ],
    title="Big Cat Classification",
    description="Compare Custom Model vs CLIP",
    examples=example_images,
)

if __name__ == "__main__":
    iface.launch()