# Hugging Face Space: itsJasminZWIN/chihiro-classifier — app.py
# (scraped page banner removed; last update: commit bdea9e6, verified)
import gradio as gr
from transformers import pipeline
import torchvision.transforms as transforms
from PIL import Image
import torch
from torchvision import models
from huggingface_hub import hf_hub_download
# Download the fine-tuned ViT checkpoint from the Hub and rebuild the
# matching architecture: vit_b_16 with its classification head swapped
# for a 2-way linear layer (chihiro vs. not chihiro), then load the
# trained weights on CPU.
model_path = hf_hub_download("itsJasminZWIN/chihiro-classifier", filename="chihiro_classifier.pth")
vit_classifier = models.vit_b_16(weights=None)  # bare architecture; the state dict supplies all weights
vit_classifier.heads.head = torch.nn.Linear(vit_classifier.heads.head.in_features, 2)
vit_classifier.load_state_dict(torch.load(model_path, map_location="cpu"))
vit_classifier.eval()  # inference mode: disables dropout / batch-norm updates
# CLIP zero-shot classifier via the transformers pipeline; placed on
# GPU (device 0) when CUDA is available, otherwise on CPU (-1).
clip_detector = pipeline(
    model="openai/clip-vit-base-patch32",
    task="zero-shot-image-classification",
    device=0 if torch.cuda.is_available() else -1  # use GPU if available
)
# Candidate labels, shared by the ViT head (index order 0/1) and the
# CLIP zero-shot prompts.
label_names = ["chihiro", "not chihiro"]
# Preprocessing for the ViT branch: resize to the model's 224x224 input,
# then convert to a [0, 1] float tensor.
# NOTE(review): no Normalize() step — assumes the classifier was trained
# with this same transform; confirm against the training pipeline.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
def classify_image(image):
    """Run both classifiers on *image* and return their score dicts.

    Parameters
    ----------
    image : PIL.Image.Image | str | None
        Image from the Gradio component, a file path (e.g. from the
        Examples gallery), or ``None`` when the user clicks Classify
        without selecting an image.

    Returns
    -------
    tuple[dict, dict]
        ``(vit_scores, clip_scores)`` — label -> probability mappings,
        rounded to 4 decimal places, suitable for ``gr.Label``.
    """
    # Gradio passes None when the button is clicked with no image
    # selected; the original code crashed here with a TypeError.
    if image is None:
        return {}, {}
    # Accept a file path as well as a PIL image.
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    # ViT branch: preprocess, forward pass, softmax over the two classes.
    img_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        outputs = vit_classifier(img_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)[0]
    vit_output = {label_names[i]: round(float(probs[i]), 4) for i in range(2)}
    # CLIP branch: zero-shot classification against the same labels.
    clip_results = clip_detector(image, candidate_labels=label_names)
    clip_output = {res["label"]: round(res["score"], 4) for res in clip_results}
    return vit_output, clip_output
# Bundled example images from the local repo, wrapped one-per-row in the
# nested-list shape gr.Examples expects.
_example_files = (
    "000002.png",
    "000011.jpg",
    "000048.png",
    "Chihiro_13.PNG",
    "Kiki_01.PNG",
    "not_chihiro01.jpg",
    "not_chihiro02.jpg",
    "chihiro_01.jpg",
)
example_images = [[f"example_images/{name}"] for name in _example_files]
# Memo table for CLIP results, keyed by image content.
# NOTE(review): unbounded — fine for a short-lived demo Space, but it
# grows for every distinct image classified.
clip_cache = {}


def get_clip_prediction(image):
    """Return CLIP zero-shot scores for *image*, memoised by pixel content.

    NOTE(review): this helper is defined but never wired into the UI
    below (the click handler calls the uncached path directly) — confirm
    whether it should replace the CLIP call in ``classify_image``.

    Parameters
    ----------
    image : PIL.Image.Image
        Image to classify.

    Returns
    -------
    dict
        label -> score mapping, rounded to 4 decimal places.
    """
    import hashlib

    # Key on (size, mode, sha1-of-pixels). The previous key,
    # hash(image.tobytes()), could silently collide (truncated builtin
    # hash) and ignored dimensions — same raw bytes reshaped to a
    # different size would wrongly share a cache entry.
    digest = hashlib.sha1(image.tobytes()).hexdigest()
    key = (image.size, image.mode, digest)
    if key not in clip_cache:
        results = clip_detector(image, candidate_labels=label_names)
        clip_cache[key] = {res["label"]: round(res["score"], 4) for res in results}
    return clip_cache[key]
# Gradio interface: image input + classify button on the left, the two
# classifiers' score panels on the right, example galleries below.
with gr.Blocks() as demo:
    gr.Markdown("## Chihiro Classifier Comparison")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload or Select Image")
            submit_button = gr.Button("Classify")
        with gr.Column():
            vit_output = gr.Label(label="ViT Classification")
            clip_output = gr.Label(label="CLIP Zero-Shot Classification")
    # Wire the button: classify_image returns (vit_scores, clip_scores),
    # one dict per Label component.
    submit_button.click(classify_image, inputs=image_input, outputs=[vit_output, clip_output])
    gr.Markdown("### 🧪 Example Images")
    # Two example tabs; clicking an example loads it into image_input.
    # NOTE(review): the grouping looks inconsistent with the filenames —
    # "Kiki_01.PNG" sits under "Trained Images" while "chihiro_01.jpg"
    # is under "Foreign Images"; confirm against the training set.
    with gr.Tabs():
        with gr.Tab("🧠 Trained Images"):
            gr.Examples(
                examples=[
                    ["example_images/Kiki_01.PNG"],
                    ["example_images/000048.png"],
                    ["example_images/Chihiro_13.PNG"]
                ],
                inputs=image_input
            )
        with gr.Tab("🌐 Foreign Images"):
            gr.Examples(
                examples=[
                    ["example_images/not_chihiro01.jpg"],
                    ["example_images/not_chihiro02.jpg"],
                    ["example_images/chihiro_01.jpg"],
                ],
                inputs=image_input
            )

# Launch the Space's web server.
demo.launch()