|
|
import os, zipfile, shutil |
|
|
import torch, clip |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# All tensors and the model are placed on this device.
device = "cpu"

# clip.load returns the network together with the callable that converts a
# PIL image into the input tensor the network expects.
model, preprocess = clip.load(
    "ViT-B/32",
    device=device,
)
|
|
|
|
|
|
|
|
# Candidate labels; the position in this list is the class index used when
# picking the best-scoring label for an image.
landmarks = [
    "Starfield library",
    "Statue of King Sejong",
    "Cheomseongdae",
    "N Seoul Tower",
    "63 Building",
    "Jongno Tower",
    "Gocheok Sky Dome",
    "Myeongdong Cathedral",
]

# Tokenise the labels once at import time so every request reuses them.
text_tokens = clip.tokenize(landmarks)
text_tokens = text_tokens.to(device)
|
|
|
|
|
|
|
|
def _iter_image_files(root, extensions=('.jpg', '.jpeg', '.png')):
    """Yield ``(relative_name, full_path)`` for every image file below *root*.

    Sub-directories are traversed because zipping a folder puts the images
    one level down inside the archive. Directories and files are visited in
    sorted order so the output does not depend on the file system.
    """
    for dirpath, dirnames, filenames in os.walk(root):
        # Prune macOS archive metadata: "__MACOSX" holds resource forks that
        # carry image extensions but are not decodable images.
        dirnames[:] = sorted(d for d in dirnames if d != "__MACOSX")
        for filename in sorted(filenames):
            if filename.lower().endswith(extensions):
                full_path = os.path.join(dirpath, filename)
                yield os.path.relpath(full_path, root), full_path


def classify_images(zip_file, threshold=0.3):
    """Classify every image in a zip archive against the known landmarks.

    The archive is extracted into ``workspace/images``. Each ``.jpg``,
    ``.jpeg`` or ``.png`` file (at any depth) is scored against
    ``text_tokens`` with the CLIP model. Images whose best softmax
    probability is strictly greater than *threshold* are copied to
    ``workspace/matched`` and listed in the result table.

    Args:
        zip_file: Path or file object accepted by ``zipfile.ZipFile``.
            ``None`` (nothing uploaded) yields an empty result.
        threshold: Minimum best-label probability for an image to count as
            a match.

    Returns:
        A ``(results, matched_zip)`` tuple. ``results`` is a list of
        ``[filename, label, confidence]`` rows, where ``filename`` is the
        path relative to the archive root; images that could not be
        processed appear as ``[filename, "Error: ...", 0.0]``.
        ``matched_zip`` is the path of a zip archive containing the matched
        images, or ``None`` when *zip_file* is ``None``.

    Raises:
        zipfile.BadZipFile: If *zip_file* is not a valid zip archive.
    """
    # Bail out before touching the workspace when nothing was uploaded.
    if zip_file is None:
        return [], None

    # NOTE: the workspace is a fixed relative directory that is wiped on every
    # call, so overlapping calls would interfere with each other.
    work_dir = "workspace"
    images_dir = os.path.join(work_dir, "images")
    matched_dir = os.path.join(work_dir, "matched")

    shutil.rmtree(work_dir, ignore_errors=True)
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(matched_dir, exist_ok=True)

    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        zip_ref.extractall(images_dir)

    results = []

    for rel_name, path in _iter_image_files(images_dir):
        # Best effort per file: one unreadable image must not abort the batch.
        try:
            # Close the file handle as soon as the tensor has been built.
            with Image.open(path) as img:
                image = preprocess(img).unsqueeze(0).to(device)

            with torch.no_grad():
                logits_per_image, _ = model(image, text_tokens)
            probs = logits_per_image.softmax(dim=-1).cpu().numpy()[0]

            best_idx = int(np.argmax(probs))
            best_label = landmarks[best_idx]
            best_conf = float(probs[best_idx])

            if best_conf > threshold:
                # Keep the relative path so files with the same name in
                # different folders do not overwrite each other.
                destination = os.path.join(matched_dir, rel_name)
                os.makedirs(os.path.dirname(destination), exist_ok=True)
                shutil.copy(path, destination)
                results.append([rel_name, best_label, round(best_conf, 2)])
        except Exception as e:
            results.append([rel_name, f"Error: {e}", 0.0])

    matched_zip = os.path.join(work_dir, "matched_images.zip")
    # make_archive appends the ".zip" suffix itself, so pass the bare stem.
    shutil.make_archive(os.path.splitext(matched_zip)[0], 'zip', matched_dir)

    return results, matched_zip
|
|
|
|
|
|
|
|
# Gradio UI: a zip upload plus a threshold slider go in; the result table and
# a downloadable archive of the matched images come out.
demo = gr.Interface(
    fn=classify_images,
    inputs=[
        gr.File(label="Upload .zip of images", file_types=[".zip"]),
        gr.Slider(0.0, 1.0, value=0.3, label="Confidence Threshold"),
    ],
    outputs=[
        gr.Dataframe(
            headers=["Filename", "Predicted Landmark", "Confidence"],
            label="Matching Images",
        ),
        gr.File(label="Download Matched Images (.zip)"),
    ],
    title="Landmark Classifier (CLIP)",
    # The count is derived from the list so the text stays correct when
    # landmarks are added or removed.
    description=(
        f"Upload a .zip folder of images. The app finds images of these "
        f"{len(landmarks)} landmarks:\n"
        + ", ".join(landmarks)
    ),
)
|
|
|
|
|
# Start the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|
|
|