import gradio as gr
import torch
import clip
from PIL import Image
import numpy as np
import cv2
from pytorch_grad_cam import FinerCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import FinerWeightedTarget
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode

# Load CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/16", device=device)

BICUBIC = InterpolationMode.BICUBIC
PATCH_SIZE = 16  # ViT-B/16 splits the image into 16x16 patches
MAX_SIZE = 500  # upper limit for both width and height


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform_resize(h, w):
    return Compose([
        Resize((h, w), interpolation=BICUBIC),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ])
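

# Both sides are rounded up to multiples of PATCH_SIZE so the ViT patch grid
# tiles the resized image exactly. Worked example (an illustration, not part
# of the app): an 800x600 upload gives scale = min(500/600, 500/800, 1.0)
# = 0.625, so new_height = ceil(600 * 0.625 / 16) * 16 = 384 and
# new_width = ceil(800 * 0.625 / 16) * 16 = 512.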
def preprocess_uploaded_image(uploaded_image, ori_height, ori_width, max_size=MAX_SIZE):
    scale = min(max_size / ori_height, max_size / ori_width, 1.0)
    new_height = int(np.ceil((ori_height * scale) / PATCH_SIZE) * PATCH_SIZE)
    new_width = int(np.ceil((ori_width * scale) / PATCH_SIZE) * PATCH_SIZE)
    preprocess = _transform_resize(new_height, new_width)
    image = preprocess(uploaded_image)
    return [image], new_width, new_height
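

# grad-cam hooks activations in CLIP's sequence-first layout
# (tokens, batch, channels). The helper below moves batch first, drops the
# leading CLS token, and folds the remaining patch tokens into a
# (batch, channels, height, width) feature map. The 28x28 default matches a
# 448x448 input at 16-pixel patches; presumably the FinerCAM implementation
# supplies the actual grid size for other input resolutions.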
def reshape_transform(tensor, height=28, width=28):
    tensor = tensor.permute(1, 0, 2)
    result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
    result = result.transpose(2, 3).transpose(1, 2)
    return result
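

# Standard CLIP zero-shot classifier construction (prompt ensembling): each
# template is filled with the class name, every prompt is embedded and
# L2-normalized, the prompts of a class are averaged, and the mean is
# renormalized to unit length, giving one text embedding per class.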
def zeroshot_classifier(classnames, templates, model):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            texts = [template.format(classname) for template in templates]
            texts = clip.tokenize(texts).to(device)
            class_embeddings = model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights.t()
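

# Main pipeline. Finer-CAM contrasts the target concept with a visually
# similar "compared" concept so the heatmap concentrates on the
# discriminative details rather than the whole object; with the comparison
# strength at 0 it reduces to plain Grad-CAM (see the app description below).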
def process_inputs(target_text, compared_text, image, weight_value):
    if image is None or not target_text or not compared_text:
        return None
    # Support multiple targets separated by commas
    target_texts = [t.strip() for t in target_text.split(',') if t.strip()]
    if not target_texts:
        return None
    # Use the first target for visualization
    primary_target = target_texts[0]
    image = image.convert("RGB")
    ori_width, ori_height = image.size
    ms_imgs, new_width, new_height = preprocess_uploaded_image(image, ori_height, ori_width)
    class_names = [primary_target, compared_text]
    text_features = zeroshot_classifier(class_names, ['a clean origami {}.'], model)
    # Hook the pre-attention LayerNorm of the last transformer block
    target_layers = [model.visual.transformer.resblocks[-1].ln_1]
    cam = FinerCAM(model=model, target_layers=target_layers, reshape_transform=reshape_transform)
    # A single scale is used, so the loop returns on its first pass
    for ms_image in ms_imgs:
        ms_image = ms_image.unsqueeze(0).to(device)
        h, w = ms_image.shape[-2], ms_image.shape[-1]
        # The bundled clip module is modified: stock OpenAI CLIP's
        # encode_image takes only the image tensor, not (image, h, w)
        image_features, _ = model.encode_image(ms_image, h, w)
        input_tensor = [image_features, text_features, h, w]
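        # FinerWeightedTarget(0, [1], weight_value): class 0 (the primary
        # target) is explained while class 1 (the compared text) is weighed
        # against it with strength weight_value; per the app description,
        # a strength of 0 recovers plain Grad-CAM.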
        targets = [FinerWeightedTarget(0, [1], weight_value)]
        grayscale_cam, _, _ = cam(input_tensor=input_tensor, targets=targets, target_size=None, alpha=1, k=3)
        grayscale_cam = grayscale_cam[0, :]
        # Upsample the CAM to the original resolution and overlay it
        grayscale_cam_highres = cv2.resize(grayscale_cam, (ori_width, ori_height))
        np_image = np.array(image) / 255.0
        visualization = show_cam_on_image(np_image, grayscale_cam_highres, use_rgb=True)
        return Image.fromarray(visualization)
    return image


# Example inputs: (target, compared class, image path, comparison strength)
examples = [
    ["Car Headlight", "Car", "data/car.png", 1.0],
    ["airplane wheel", "airplane", "data/aircraft.png", 1.0],
    ["Car wheel", "Car", "data/car2.png", 1.0],
    ["Beak", "Bird", "data/bird.png", 1.0],
    ["yellow crown", "Bird", "data/104.png", 1.0],
    ["Hindwing tails", "Butterfly", "data/butterfly.png", 1.0],
]
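
# Note: the data/*.png files referenced above are assumed to live next to
# this script (e.g., in the same Space repository); the examples will not
# load without them.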

# Gradio Interface
demo = gr.Interface(
    fn=process_inputs,
    inputs=[
        gr.Textbox(label="Target"),
        gr.Textbox(label="Compared Target"),
        gr.Image(label="Upload Image", type="pil"),
        gr.Slider(label="Comparison strength", minimum=0.0, maximum=1.0, value=1.0, step=0.01),
    ],
    outputs=gr.Image(label="FinerCAM Output"),
    title="FinerCAM Visualizer with CLIP",
    description="Upload an image and enter target texts (comma-separated for multiple targets; only the first is visualized). Adjust the comparison strength. (Finer-CAM reduces to Grad-CAM when the comparison strength is 0.)",
    examples=examples,
)


if __name__ == "__main__":
    demo.launch()
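
# On Hugging Face Spaces, launch() picks up host and port automatically;
# when running elsewhere, Gradio's standard share flag can create a
# temporary public URL, e.g. demo.launch(share=True).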