Spaces:
Runtime error
Runtime error
| import os | |
| import base64 | |
| import gradio as gr | |
| import openai | |
| from transformers import pipeline | |
# Closed set of classes shared by all three classifiers. Kept as a list: its
# repr is interpolated verbatim into the prompt built in classify_openai.
FLOWER_LABELS = [
    "daisy",
    "dandelion",
    "rose",
    "sunflower",
    "tulip",
]

# Image-classification pipeline loaded from a directory next to this file.
vit_classifier = pipeline(
    task="image-classification",
    model="./flower-vit-classifier",
)

# Zero-shot pipeline; candidate labels are supplied at call time.
clip_classifier = pipeline(
    task="zero-shot-image-classification",
    model="openai/clip-vit-large-patch14",
)

# OpenAI client; the API key is read from the OPENAI_API_KEY environment variable.
client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
def normalize_label(text):
    """Map a free-form model reply onto one of the known flower labels.

    The reply is trimmed, lower-cased and stripped of periods and commas.
    The first entry of ``FLOWER_LABELS`` (in list order) that occurs as a
    substring of the cleaned reply is returned; if no entry occurs, the
    cleaned reply itself is returned.
    """
    cleaned = text.strip().lower().translate(str.maketrans("", "", ".,"))
    # Substring matching, so a reply such as "a rose" still resolves to "rose".
    return next(
        (candidate for candidate in FLOWER_LABELS if candidate in cleaned),
        cleaned,
    )
def classify_openai(image_path):
    """Classify a flower image with the OpenAI vision model.

    The file is read, base64-encoded and sent as a data URL together with a
    prompt asking for exactly one of ``FLOWER_LABELS``.

    Args:
        image_path: Path of the image file on disk.

    Returns:
        The model's text reply after ``normalize_label`` has been applied.

    Raises:
        OSError: If the image file cannot be opened or read.
    """
    # Function-scope import so this block does not depend on the file header.
    import mimetypes

    with open(image_path, "rb") as image_file:
        b64 = base64.b64encode(image_file.read()).decode("utf-8")

    # Declare the file's real media type in the data URL instead of always
    # claiming JPEG; uploads may be PNG, WebP, etc. Fall back to JPEG (the
    # previous behaviour) when the extension is unknown or not an image type.
    mime_type = mimetypes.guess_type(image_path)[0]
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    prompt = (
        "\n"
        "You are a flower image classifier.\n"
        "Classify the image into exactly one of these labels:\n"
        f"{FLOWER_LABELS}\n"
        "Return only one label and nothing else.\n"
    )

    response = client.responses.create(
        model="gpt-4.1-mini",
        input=[
            {
                "role": "user",
                "content": [
                    {"type": "input_text", "text": prompt},
                    {
                        "type": "input_image",
                        "image_url": f"data:{mime_type};base64,{b64}",
                    },
                ],
            }
        ],
    )
    # The reply may carry extra words or punctuation; reduce it to a label.
    return normalize_label(response.output_text)
def classify_flower(image):
    """Run one image through all three classifiers and collect the results.

    Args:
        image: Path of the image file. The interface is configured with
            ``gr.Image(type="filepath")``, so Gradio passes a path, or
            ``None`` when nothing was uploaded.

    Returns:
        A dict with three entries: ``"your_vit_model"`` and
        ``"clip_zero_shot"`` map each returned label to its score as a
        float, and ``"openai_prediction"`` holds the single label produced
        by ``classify_openai``.

    Raises:
        gr.Error: If no image was provided.
    """
    # An empty submission arrives as None; report it to the user instead of
    # failing inside the first pipeline call with an opaque error.
    if image is None:
        raise gr.Error("Please upload an image first.")

    vit_results = vit_classifier(image)
    vit_output = {item["label"]: float(item["score"]) for item in vit_results}

    clip_results = clip_classifier(image, candidate_labels=FLOWER_LABELS)
    clip_output = {item["label"]: float(item["score"]) for item in clip_results}

    openai_label = classify_openai(image)

    return {
        "your_vit_model": vit_output,
        "clip_zero_shot": clip_output,
        "openai_prediction": openai_label,
    }
# Example inputs offered in the UI: one single-element row per flower class,
# each row holding the path of one sample image.
example_images = [
    [sample_path]
    for sample_path in (
        "example_images/train/daisy/2488902131_3417698611_n.jpg",
        "example_images/train/dandelion/3554992110_81d8c9b0bd_m.jpg",
        "example_images/train/rose/3415176946_248afe9f32.jpg",
        "example_images/train/sunflower/4932144003_cbffc89bf0.jpg",
        "example_images/train/tulip/4579128789_1561575458_n.jpg",
    )
]
# Gradio UI: a single image input (handed to classify_flower as a file path)
# and a JSON panel showing the combined predictions.
iface = gr.Interface(
    title="Flower Classification Comparison",
    description=(
        "Compare a fine-tuned ViT model, CLIP zero-shot, "
        "and OpenAI vision on flower images."
    ),
    fn=classify_flower,
    inputs=gr.Image(type="filepath"),
    outputs=gr.JSON(),
    examples=example_images,
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()