# Hugging Face Space header residue (not code), preserved as comments:
#   ZihengZ's picture — "update" — commit d6c35f4
import gradio as gr
import torch
import clip
from PIL import Image
import numpy as np
import cv2
from pytorch_grad_cam import FinerCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import FinerWeightedTarget
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.transforms import InterpolationMode
import os
# --- Global CLIP setup --------------------------------------------------
# Load CLIP model onto GPU when available, otherwise CPU.
# NOTE(review): clip.load likely downloads the ViT-B/16 weights on first
# run if not cached — confirm for deployment environments.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/16", device=device)
BICUBIC = InterpolationMode.BICUBIC
# ViT-B/16 tokenizes the image into 16x16 patches; resized image sides are
# rounded to multiples of this so the patch grid divides evenly.
PATCH_SIZE = 16
MAX_SIZE = 500 # Upper limit for resized width and height (before patch rounding)
def _convert_image_to_rgb(image):
    """Return *image* converted to RGB mode (PIL ``Image.convert``)."""
    rgb_image = image.convert("RGB")
    return rgb_image
def _transform_resize(h, w):
    """Build the CLIP preprocessing pipeline for a fixed (h, w) output.

    Resizes with bicubic interpolation, forces RGB, converts to a tensor,
    and normalizes with the standard CLIP mean/std.
    """
    steps = [
        Resize((h, w), interpolation=BICUBIC),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ]
    return Compose(steps)
def preprocess_uploaded_image(uploaded_image, ori_height, ori_width, max_size=MAX_SIZE):
    """Resize an uploaded PIL image for the CLIP ViT backbone.

    The image is downscaled (never upscaled) so both sides fit within
    *max_size*, then each side is rounded up to a multiple of PATCH_SIZE
    so the patch grid divides the image evenly.

    Returns a one-element list with the preprocessed tensor, plus the new
    width and height.
    """
    shrink = min(1.0, max_size / ori_height, max_size / ori_width)

    def _snap_to_patch(side):
        # Round the scaled side UP to the next PATCH_SIZE multiple.
        return int(np.ceil((side * shrink) / PATCH_SIZE) * PATCH_SIZE)

    new_height = _snap_to_patch(ori_height)
    new_width = _snap_to_patch(ori_width)
    tensor = _transform_resize(new_height, new_width)(uploaded_image)
    return [tensor], new_width, new_height
def reshape_transform(tensor, height=28, width=28):
    """Turn a ViT token sequence into a CNN-style feature map for Grad-CAM.

    Input is (seq_len, batch, dim) as produced by CLIP's transformer; the
    CLS token (position 0) is dropped and the remaining height*width patch
    tokens are laid out as a (batch, dim, height, width) map.
    """
    batch_first = tensor.permute(1, 0, 2)
    patch_tokens = batch_first[:, 1:, :]
    grid = patch_tokens.reshape(batch_first.size(0), height, width, batch_first.size(2))
    # (batch, height, width, dim) -> (batch, dim, height, width)
    return grid.permute(0, 3, 1, 2)
def zeroshot_classifier(classnames, templates, model):
    """Build CLIP zero-shot text weights for *classnames*.

    For each class, every template is formatted with the class name,
    encoded, and L2-normalized; the per-template embeddings are averaged
    and re-normalized into a single class embedding. Returns a tensor of
    shape (num_classes, embed_dim).
    """
    with torch.no_grad():
        per_class = []
        for name in classnames:
            prompts = clip.tokenize(
                [template.format(name) for template in templates]
            ).to(device)
            embeddings = model.encode_text(prompts)
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
            mean_embedding = embeddings.mean(dim=0)
            per_class.append(mean_embedding / mean_embedding.norm())
        weights = torch.stack(per_class, dim=1).to(device)
    return weights.t()
def process_inputs(target_text, compared_text, image, weight_value):
    """Run Finer-CAM for the primary target vs. the compared target.

    Args:
        target_text: Comma-separated target names; only the first non-empty
            entry is visualized.
        compared_text: Name of the class to contrast the target against.
        image: Uploaded PIL image (or None).
        weight_value: Comparison strength passed to FinerWeightedTarget
            (0 degrades Finer-CAM to plain Grad-CAM).

    Returns:
        A PIL image with the CAM heatmap overlaid on the original-resolution
        input, or None when any required input is missing.
    """
    if image is None or not target_text or not compared_text:
        return None
    # Support multiple targets separated by commas. Filter empty entries so
    # inputs like ", cat" or "," cannot make the primary target "".
    target_texts = [t.strip() for t in target_text.split(',') if t.strip()]
    if not target_texts:
        return None
    # Use the first target for visualization; extra targets are ignored.
    primary_target = target_texts[0]

    image = image.convert("RGB")
    ori_width, ori_height = image.size
    ms_imgs, _new_width, _new_height = preprocess_uploaded_image(image, ori_height, ori_width)

    class_names = [primary_target, compared_text]
    text_features = zeroshot_classifier(class_names, ['a clean origami {}.'], model)

    # Hook the first LayerNorm of the last transformer block for CAM gradients.
    target_layers = [model.visual.transformer.resblocks[-1].ln_1]
    cam = FinerCAM(model=model, target_layers=target_layers, reshape_transform=reshape_transform)

    # preprocess_uploaded_image returns a single-element list; the previous
    # loop-with-return here only ever ran one iteration (and its trailing
    # `return image` was unreachable), so process the one tensor directly.
    ms_image = ms_imgs[0].unsqueeze(0).to(device)
    h, w = ms_image.shape[-2], ms_image.shape[-1]
    image_features, _ = model.encode_image(ms_image, h, w)
    input_tensor = [image_features, text_features, h, w]
    # Class 0 (primary target) is weighted against class 1 (compared target).
    targets = [FinerWeightedTarget(0, [1], weight_value)]
    grayscale_cam, _, _ = cam(input_tensor=input_tensor, targets=targets, target_size=None, alpha=1, k=3)
    grayscale_cam = grayscale_cam[0, :]
    # Upsample the CAM back to the original resolution before overlaying.
    grayscale_cam_highres = cv2.resize(grayscale_cam, (ori_width, ori_height))
    np_image = np.array(image) / 255.0
    visualization = show_cam_on_image(np_image, grayscale_cam_highres, use_rgb=True)
    return Image.fromarray(visualization)
# Examples with multiple targets for each image.
# Row layout matches the Interface inputs:
#   [target text, compared target text, image path, comparison strength]
examples = [
["Car Headlight", "Car", "data/car.png", 1.0],
["airplane wheel", "airplane", "data/aircraft.png", 1.0],
["Car wheel", "Car", "data/car2.png", 1.0],
["Beak", "Bird", "data/bird.png", 1.0],
["yellow crown", "Bird", "data/104.png", 1.0],
["Hindwing tails", "Butterfly", "data/butterfly.png", 1.0],
]
# Gradio Interface
# The four inputs feed process_inputs positionally: target text(s),
# compared target text, uploaded image, and the comparison-strength slider.
demo = gr.Interface(
fn=process_inputs,
inputs=[
gr.Textbox(label="Target"),
gr.Textbox(label="Compared Target"),
gr.Image(label="Upload Image", type="pil"),
gr.Slider(label="Comparison strength", minimum=0.0, maximum=1.0, value=1.0, step=0.01),
],
outputs=gr.Image(label="FinerCAM Output"),
title="FinerCAM Visualizer with CLIP",
description="Upload an image and enter target texts (comma-separated for multiple targets). Adjust the comparison strength. (Finer-CAM degrades to Grad-CAM when comparison strength = 0)",
examples=examples
)
# Launch the app only when run as a script (not on import).
if __name__ == "__main__":
demo.launch()