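"""Gradio demo comparing two "chihiro vs. not chihiro" image classifiers:
a fine-tuned torchvision ViT (vit_b_16) and zero-shot CLIP
(openai/clip-vit-base-patch32)."""
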
import gradio as gr
from transformers import pipeline
import torchvision.transforms as transforms
from PIL import Image
import torch
from torchvision import models
from huggingface_hub import hf_hub_download

# Download and load the fine-tuned ViT model weights from the Hugging Face Hub
model_path = hf_hub_download("itsJasminZWIN/chihiro-classifier", filename="chihiro_classifier.pth")
vit_classifier = models.vit_b_16(weights=None)
vit_classifier.heads.head = torch.nn.Linear(vit_classifier.heads.head.in_features, 2)
vit_classifier.load_state_dict(torch.load(model_path, map_location="cpu"))
vit_classifier.eval()
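# NOTE: torch.load unpickles arbitrary objects; on PyTorch >= 1.13 you can pass
# weights_only=True to the torch.load call above for safer checkpoint loading.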

# Load the CLIP zero-shot classification pipeline
clip_detector = pipeline(
    model="openai/clip-vit-base-patch32",
    task="zero-shot-image-classification",
    device=0 if torch.cuda.is_available() else -1,  # use GPU if available
)

# Labels shared by both classifiers
label_names = ["chihiro", "not chihiro"]

# Image transform for ViT
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
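# NOTE: no mean/std normalization is applied here. This assumes the checkpoint
# was trained with the same bare Resize + ToTensor pipeline; if it was trained
# with ImageNet normalization, add a matching transforms.Normalize step.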

# Cache CLIP results so repeated runs on the same image skip re-inference
clip_cache = {}

def get_clip_prediction(image):
    key = hash(image.tobytes())  # crude hash of the raw image content
    if key not in clip_cache:
        results = clip_detector(image, candidate_labels=label_names)
        clip_cache[key] = {res["label"]: round(res["score"], 4) for res in results}
    return clip_cache[key]

# Classification function: run both models and return their label -> score dicts
def classify_image(image):
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    # ViT: preprocess, forward pass, softmax over the two classes
    img_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        outputs = vit_classifier(img_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)[0]
    vit_output = {label_names[i]: round(float(probs[i]), 4) for i in range(2)}
    # CLIP: zero-shot prediction via the cached helper above
    clip_output = get_clip_prediction(image)
    return vit_output, clip_output
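# Quick sanity check when running locally (uses an example image bundled below):
#   print(classify_image("example_images/chihiro_01.jpg"))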

# Full set of example images bundled with the repo; the demo tabs below
# surface curated subsets of this list.
example_images = [
    ["example_images/000002.png"],
    ["example_images/000011.jpg"],
    ["example_images/000048.png"],
    ["example_images/Chihiro_13.PNG"],
    ["example_images/Kiki_01.PNG"],
    ["example_images/not_chihiro01.jpg"],
    ["example_images/not_chihiro02.jpg"],
    ["example_images/chihiro_01.jpg"],
]

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Chihiro Classifier Comparison")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload or Select Image")
            submit_button = gr.Button("Classify")
        with gr.Column():
            vit_output = gr.Label(label="ViT Classification")
            clip_output = gr.Label(label="CLIP Zero-Shot Classification")

    submit_button.click(classify_image, inputs=image_input, outputs=[vit_output, clip_output])

    gr.Markdown("### 🧪 Example Images")
    with gr.Tabs():
        with gr.Tab("🧠 Trained Images"):
            gr.Examples(
                examples=[
                    ["example_images/Kiki_01.PNG"],
                    ["example_images/000048.png"],
                    ["example_images/Chihiro_13.PNG"],
                ],
                inputs=image_input,
            )
        with gr.Tab("🌐 Foreign Images"):
            gr.Examples(
                examples=[
                    ["example_images/not_chihiro01.jpg"],
                    ["example_images/not_chihiro02.jpg"],
                    ["example_images/chihiro_01.jpg"],
                ],
                inputs=image_input,
            )

demo.launch()