# app.py — Hugging Face Space: big cat classification demo
# (comparing a fine-tuned ViT, zero-shot CLIP, and an OpenAI vision model)
| import os | |
| import json | |
| import base64 | |
| from pathlib import Path | |
| import gradio as gr | |
| from openai import OpenAI | |
| from transformers import pipeline | |
# ----------------------------
# Paths and labels
# ----------------------------
BASE_DIR = Path(__file__).resolve().parent
EXAMPLE_DIR = BASE_DIR / "example_images"  # demo images bundled with the Space
MODEL_PATH = "DKatheesrupan/cat-vit"  # fine-tuned ViT checkpoint on the HF Hub
CAT_LABELS = ["cheetah", "leopard", "lion", "puma", "tiger"]  # closed label set shared by all three classifiers
print("Loading custom model...")
# Supervised image classifier fine-tuned on the five big-cat classes.
vit_classifier = pipeline(
    "image-classification",
    model=MODEL_PATH
)
print("Loading CLIP model...")
# Zero-shot classifier; candidate labels are supplied at call time.
clip_classifier = pipeline(
    task="zero-shot-image-classification",
    model="openai/clip-vit-base-patch32"
)
# OpenAI key comes from Hugging Face Space secret: OPENAI_API_KEY
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
| # ---------------------------- | |
| # Helper functions | |
| # ---------------------------- | |
def encode_image(image_path):
    """Return the contents of *image_path* as a base64 string (UTF-8 decoded)."""
    raw_bytes = Path(image_path).read_bytes()
    return base64.b64encode(raw_bytes).decode("utf-8")
def normalize_custom_labels(results):
    """Convert pipeline output into a ``{label: score}`` dict with readable names.

    Generic ids such as ``"LABEL_0"`` are translated to cat names; any other
    label is simply lower-cased. Scores are coerced to ``float``.
    """
    id2label = {
        "LABEL_0": "cheetah",
        "LABEL_1": "leopard",
        "LABEL_2": "lion",
        "LABEL_3": "puma",
        "LABEL_4": "tiger",
    }
    return {
        id2label.get(item["label"], item["label"].lower()): float(item["score"])
        for item in results
    }
def classify_with_openai(image_path):
    """Classify an image with an OpenAI vision model via the Responses API.

    Returns a ``{label: probability}`` distribution over CAT_LABELS: the
    model's reported confidence on its chosen label, with the remaining
    probability mass spread uniformly over the other labels.

    On any failure (API error, malformed JSON, out-of-vocabulary label)
    returns the sentinel ``{"unknown": 1.0}`` so the UI keeps working.
    """
    base64_image = encode_image(image_path)
    prompt = f"""
You are a big cat classifier.
Classify the image into exactly one of these labels:
{CAT_LABELS}
Return ONLY valid JSON.
Do not use markdown.
Do not use code fences.
Do not add explanations.
Required format:
{{"label":"one_of_{CAT_LABELS}","confidence":0.0}}
"""
    try:
        response = client.responses.create(
            model="gpt-4.1-mini",
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_text", "text": prompt},
                        {
                            "type": "input_image",
                            "image_url": f"data:image/jpeg;base64,{base64_image}"
                        }
                    ]
                }
            ]
        )
        text = response.output_text.strip()
        # Defensive cleanup: strip code fences in case the model ignores the
        # "no fences" instruction, then extract the outermost JSON object.
        text = text.replace("```json", "").replace("```", "").strip()
        start = text.find("{")
        end = text.rfind("}")
        if start != -1 and end != -1 and end > start:
            text = text[start:end + 1]
        result = json.loads(text)
        label = str(result["label"]).strip().lower()
        confidence = float(result["confidence"])
        if label not in CAT_LABELS:
            raise ValueError(f"Invalid label: {label}")
        confidence = max(0.0, min(1.0, confidence))
        # Spread the leftover probability mass evenly over the other labels.
        remaining = 1.0 - confidence
        num_other = len(CAT_LABELS) - 1
        distribution = {}
        for cat in CAT_LABELS:
            distribution[cat] = confidence if cat == label else remaining / num_other
        return distribution
    except Exception as err:
        # Don't swallow failures silently: surface them in the Space logs
        # while still returning a sentinel so the Gradio UI stays usable.
        print(f"OpenAI classification failed: {err}")
        return {"unknown": 1.0}
| # ---------------------------- | |
| # Main function | |
| # ---------------------------- | |
def classify_cat(image):
    """Run all three classifiers on *image* and return their score dicts.

    Returns a tuple of ``(custom_model_scores, clip_scores, openai_scores)``,
    each mapping a label to a confidence value.
    """
    # Fine-tuned ViT
    custom_scores = normalize_custom_labels(vit_classifier(image))

    # Zero-shot CLIP: wrap each label in a prompt, then strip the prefix
    # back off the returned labels.
    prompts = [f"a photo of a {name}" for name in CAT_LABELS]
    clip_scores = {}
    for entry in clip_classifier(image, candidate_labels=prompts):
        clean_label = entry["label"].replace("a photo of a ", "").lower()
        clip_scores[clean_label] = float(entry["score"])

    # OpenAI vision model
    openai_scores = classify_with_openai(image)

    return custom_scores, clip_scores, openai_scores
| # ---------------------------- | |
| # Example images | |
| # ---------------------------- | |
# ----------------------------
# Example images: one demo file per class, shipped with the Space.
# Each entry is a one-element list, as Gradio expects one value per input.
# ----------------------------
example_images = [
    [str(EXAMPLE_DIR / filename)]
    for filename in (
        "Cheetah_032.jpg",
        "Leopard_001.jpg",
        "Lion_003.jpg",
        "Puma_001.jpg",
        "Tiger_001.jpg",
    )
]
| # ---------------------------- | |
| # Interface | |
| # ---------------------------- | |
# ----------------------------
# Interface
# ----------------------------
# Single image input, three side-by-side Label outputs (one per classifier).
iface = gr.Interface(
    fn=classify_cat,
    # "filepath" so classify_with_openai can re-read and base64-encode the file
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.Label(label="Custom Model"),
        gr.Label(label="CLIP"),
        gr.Label(label="OpenAI")
    ],
    title="Big Cat Classification",
    description="Compare Custom Model vs CLIP vs OpenAI",
    examples=example_images
)
# Launch the Gradio app when the file is run as a script.
if __name__ == "__main__":
    iface.launch()