# Image / app.py
# Author: arihar
# Commit d337246 ("Create app.py", verified)
import os, zipfile, shutil
import torch, clip
from PIL import Image
import gradio as gr
import numpy as np
# Load the CLIP model once at import time. Use a GPU when one is present;
# otherwise fall back to CPU (CPU mode works on Spaces).
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Candidate landmarks. Each string is used verbatim in three places: as the
# CLIP text prompt (tokenized below), as the predicted label reported for an
# image, and in the UI description.
landmarks = [
    "Starfield library",
    "Statue of King Sejong",
    "Cheomseongdae",
    "N Seoul Tower",
    "63 Building",
    "Jongno Tower",
    "Gocheok Sky Dome",
    "Myeongdong Cathedral",
]

# Tokenize all prompts once; classification compares every image against
# this fixed batch of prompts.
text_tokens = clip.tokenize(landmarks).to(device)
# File extensions treated as images (matched case-insensitively).
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def _iter_image_files(root):
    """Yield ``(relative_path, absolute_path)`` for each image file under *root*.

    Subdirectories are walked because zipping a folder stores its files under
    a top-level directory. macOS archive metadata (``__MACOSX`` folders and
    ``._*`` AppleDouble files) is skipped: those entries carry image
    extensions but are not decodable images. Order is sorted for stable output.
    """
    for dirpath, dirnames, filenames in os.walk(root):
        # Prune in place so os.walk never descends into metadata folders.
        dirnames[:] = sorted(d for d in dirnames if d != "__MACOSX")
        for name in sorted(filenames):
            if name.startswith("._"):
                continue
            if not name.lower().endswith(_IMAGE_EXTENSIONS):
                continue
            abs_path = os.path.join(dirpath, name)
            yield os.path.relpath(abs_path, root), abs_path


def _encode_prompts():
    """Return the L2-normalised text embeddings of ``text_tokens`` and CLIP's logit scale."""
    with torch.no_grad():
        text_features = model.encode_text(text_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logit_scale = model.logit_scale.exp()
    return text_features, logit_scale


def _predict(path, text_features, logit_scale):
    """Return ``(best_label, best_conf)`` for the image at *path*.

    Mirrors CLIP's ``model(image, text)`` forward pass (normalised cosine
    similarity scaled by ``logit_scale``) but reuses pre-encoded prompts.
    """
    # Context manager closes the file handle once the tensor is built.
    with Image.open(path) as img:
        image = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logits = logit_scale * image_features @ text_features.t()
        probs = logits.softmax(dim=-1).cpu().numpy()[0]
    best_idx = int(np.argmax(probs))
    return landmarks[best_idx], float(probs[best_idx])


def classify_images(zip_file, threshold=0.3):
    """Classify every image in an uploaded zip against the landmark prompts.

    Args:
        zip_file: The uploaded archive, as a filesystem path or as a
            file-like object exposing ``.name`` (the shape older Gradio
            versions pass — confirm against the installed Gradio). ``None``
            when nothing was uploaded.
        threshold: Minimum probability for an image to count as a match.
            The probability is a softmax over the landmark prompts only, so
            it is relative: any image, landmark or not, gets a distribution
            summing to 1 across them.

    Returns:
        ``(rows, zip_path)``. ``rows`` holds ``[relative_filename, label,
        confidence]`` for each match above *threshold*, plus
        ``[relative_filename, "Error: ...", 0.0]`` for any image that failed.
        ``zip_path`` is an archive of the matched images (directory structure
        preserved), or ``None`` when no file was uploaded.

    Raises:
        gr.Error: If the upload is not a valid zip archive.
    """
    if zip_file is None:
        return [], None

    # Work in a fixed folder relative to the CWD, wiped on every call.
    # NOTE(review): concurrent requests would clobber each other here; this
    # relies on Gradio processing one request at a time — confirm the queue
    # configuration before raising concurrency.
    work_dir = "workspace"
    images_dir = os.path.join(work_dir, "images")
    matched_dir = os.path.join(work_dir, "matched")
    shutil.rmtree(work_dir, ignore_errors=True)
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(matched_dir, exist_ok=True)

    if isinstance(zip_file, (str, os.PathLike)):
        zip_path = zip_file
    else:
        zip_path = getattr(zip_file, "name", zip_file)

    # NOTE(review): the archive is untrusted user input. zipfile strips
    # absolute paths and ".." components during extraction, but there is no
    # size limit, so a zip bomb could exhaust disk space.
    try:
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(images_dir)
    except zipfile.BadZipFile as err:
        raise gr.Error("The uploaded file is not a valid .zip archive.") from err

    # The prompts are identical for every image, so encode them once.
    text_features, logit_scale = _encode_prompts()

    results = []
    for rel_path, abs_path in _iter_image_files(images_dir):
        try:
            best_label, best_conf = _predict(abs_path, text_features, logit_scale)
            if best_conf > threshold:
                # Keep sub-folders so same-named files from different
                # directories do not overwrite each other.
                dest = os.path.join(matched_dir, rel_path)
                os.makedirs(os.path.dirname(dest), exist_ok=True)
                shutil.copy(abs_path, dest)
                results.append([rel_path, best_label, round(best_conf, 2)])
        except Exception as e:
            # Best effort per file: report the failure in the table and move
            # on, so one corrupt image does not abort the whole batch.
            results.append([rel_path, f"Error: {e}", 0.0])

    # make_archive appends ".zip" and returns the resulting path.
    matched_zip = shutil.make_archive(
        os.path.join(work_dir, "matched_images"), "zip", matched_dir
    )
    return results, matched_zip
# Gradio UI: zip upload + threshold slider in; results table + zip of
# matched images out. The slider default matches classify_images' default.
demo = gr.Interface(
    fn=classify_images,
    inputs=[
        gr.File(label="Upload .zip of images", file_types=[".zip"]),
        gr.Slider(0.0, 1.0, value=0.3, label="Confidence Threshold"),
    ],
    outputs=[
        gr.Dataframe(
            headers=["Filename", "Predicted Landmark", "Confidence"],
            label="Matching Images",
        ),
        gr.File(label="Download Matched Images (.zip)"),
    ],
    title="Landmark Classifier (CLIP)",
    # Count derived from the list so the text cannot drift out of sync when
    # landmarks are added or removed.
    description=(
        "Upload a .zip folder of images. The app finds images of these "
        f"{len(landmarks)} landmarks:\n" + ", ".join(landmarks)
    ),
)

if __name__ == "__main__":
    demo.launch()