imagelanguage / app.py
TangYiJay's picture
app.py
366963e verified
raw
history blame
2.26 kB
import functools

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
# Material names that are scored against the uploaded target image.
LABELS = [
    "plastic",
    "metal",
    "paper",
    "cardboard",
    "glass",
    "trash",
]

# Checkpoint identifier shared by the model and its processor.
MODEL_ID = "openai/clip-vit-base-patch32"

# Both objects are created once at import time and reused by every request.
model = CLIPModel.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)
def get_image_embedding(image):
    """Return the L2-normalised CLIP image features of ``image`` as a NumPy array.

    Args:
        image: Image (or images) accepted by the module-level ``processor``.

    Returns:
        ``numpy.ndarray`` holding the image features, scaled so that each
        vector along the last axis has unit Euclidean length.
    """
    pixel_inputs = processor(images=image, return_tensors="pt")

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        features = model.get_image_features(**pixel_inputs)
        lengths = features.norm(p=2, dim=-1, keepdim=True)

    unit_features = features / lengths
    return unit_features.cpu().numpy()
@functools.lru_cache(maxsize=None)
def _label_text_embeddings(labels):
    """Return L2-normalised CLIP text features for ``labels``.

    Args:
        labels: Tuple of label strings (a tuple so it is hashable for the cache).

    Returns:
        Torch tensor with one unit-length feature row per label.

    The result is cached per distinct label tuple, so the text encoder runs
    once for a given label set instead of on every classification request.
    """
    text_inputs = processor(text=list(labels), return_tensors="pt", padding=True)
    with torch.no_grad():
        text_emb = model.get_text_features(**text_inputs)
        text_emb = text_emb / text_emb.norm(p=2, dim=-1, keepdim=True)
    return text_emb


def classify_material(base_img, target_img):
    """Label the material in ``target_img`` and report how far it is from ``base_img``.

    The predicted label depends on the target image only; the base image is
    used solely for the reported difference score.

    Args:
        base_img: Reference image, or ``None`` if nothing was uploaded.
        target_img: Image to classify, or ``None`` if nothing was uploaded.

    Returns:
        A human-readable string: either a prompt to upload both images, or the
        best-matching entry of ``LABELS`` plus the Euclidean distance between
        the two normalised image embeddings.
    """
    if base_img is None or target_img is None:
        return "Please upload both base and target images."

    # Unit-length image embeddings (NumPy arrays).
    base_emb = get_image_embedding(base_img)
    target_emb = get_image_embedding(target_img)

    # Difference score: Euclidean distance between the two embeddings.
    diff = np.linalg.norm(target_emb - base_emb)

    # Text embeddings for all labels; cached, keyed by the current label set.
    text_emb = _label_text_embeddings(tuple(LABELS))

    # Reuse the target embedding computed above instead of running the image
    # encoder a second time on the same picture.
    img_feat = torch.as_tensor(
        target_emb, dtype=text_emb.dtype, device=text_emb.device
    )

    # Both sides are unit-normalised, so the dot product is cosine similarity.
    sims = torch.matmul(img_feat, text_emb.T).squeeze(0)
    best_idx = torch.argmax(sims).item()
    best_label = LABELS[best_idx]

    return f"Detected material: {best_label}\nDifference from base: {diff:.4f}"
# Gradio UI: two PIL image inputs are passed to classify_material and its
# string result is shown in a single text box.
demo = gr.Interface(
    title="Material Classification (CLIP, CPU Mode)",
    description=(
        "Upload a base image (background) and a target image (with object). "
        "The model detects what new material appears: plastic, metal, paper, "
        "cardboard, glass, or trash."
    ),
    fn=classify_material,
    inputs=[
        gr.Image(label="Base Image", type="pil"),
        gr.Image(label="Target Image", type="pil"),
    ],
    outputs=gr.Textbox(label="Detection Result"),
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()