Spaces:
Runtime error
Runtime error
| import sys | |
| sys.path.insert(0, './code') | |
| from datamodules.transformations import UnNest | |
| from models.interpretation import ImageInterpretationNet | |
| from transformers import ViTFeatureExtractor, ViTForImageClassification | |
| from utils.plot import smoothen, draw_mask_on_image, draw_heatmap_on_image | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
# Hugging Face Hub identifiers of the two Vision Transformer checkpoints used
# by the demo: one fine-tuned on CIFAR-10, one on ImageNet.
hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
hf_model_imagenet = "google/vit-base-patch16-224"

# Load Vision Transformers; nn.Module.eval() returns the module itself, so the
# classifiers are switched to inference mode right where they are created.
vit = ViTForImageClassification.from_pretrained(hf_model)
vit.eval()
vit_imagenet = ViTForImageClassification.from_pretrained(hf_model_imagenet)
vit_imagenet.eval()

# Load the matching feature extractors and wrap each in the project's UnNest
# transformation (defined in datamodules.transformations).
feature_extractor = UnNest(
    ViTFeatureExtractor.from_pretrained(hf_model, return_tensors="pt")
)
feature_extractor_imagenet = UnNest(
    ViTFeatureExtractor.from_pretrained(hf_model_imagenet, return_tensors="pt")
)

# Load Vision DiffMask checkpoints, attach each to its Vision Transformer, and
# only then put the interpretation network into inference mode.
diffmask = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask.ckpt')
diffmask.set_vision_transformer(vit)
diffmask.eval()

diffmask_imagenet = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask_imagenet.ckpt')
diffmask_imagenet.set_vision_transformer(vit_imagenet)
diffmask_imagenet.eval()
# Define mask plotting functions
def _to_hwc_array(img):
    """Move the first axis last, clip values to [0, 1] and return a numpy array.

    The input is assumed to be a channels-first (C, H, W) tensor as produced
    by the ``utils.plot`` drawing helpers — TODO confirm.
    """
    channels_last = img.permute(1, 2, 0)
    clipped = channels_last.clip(0, 1)
    return clipped.numpy()


def draw_mask(image, mask):
    """Overlay the smoothened ``mask`` on ``image`` and return it as a numpy array."""
    overlay = draw_mask_on_image(image, smoothen(mask))
    return _to_hwc_array(overlay)


def draw_heatmap(image, mask):
    """Render the smoothened ``mask`` as a heatmap on ``image`` and return it as a numpy array."""
    heat = draw_heatmap_on_image(image, smoothen(mask))
    return _to_hwc_array(heat)
# Define callable method for the demo
def get_mask(image, model_name: str):
    """Run Vision DiffMask on an uploaded image and summarise the result.

    Args:
        image: Image as a numpy array laid out (H, W, C), presumably uint8 in
            [0, 255] (it is divided by 255 below), as supplied by
            ``gr.inputs.Image(type="numpy")``. ``None`` when nothing is uploaded.
        model_name: ``'DiffMask-CIFAR-10'`` or ``'DiffMask-ImageNet'``.

    Returns:
        ``(visualisation, orig_probs, mask_probs)``. ``visualisation`` is the
        masked image and the heatmap joined side by side with ``np.hstack``.
        ``orig_probs`` and ``mask_probs`` map class names to softmax
        probabilities for the original and the masked input.
        ``(None, None, None)`` is returned when ``image`` is ``None`` or
        ``model_name`` is not one of the two known names. A live interface
        can call this before a model is chosen.
    """
    # Fixed seed so repeated calls on the same input produce the same output.
    torch.manual_seed(seed=0)
    if image is None:
        return None, None, None

    # Select the DiffMask model together with the feature extractor loaded from
    # the same Vision Transformer checkpoint (previously the CIFAR-10 extractor
    # was used for both models).
    if model_name == 'DiffMask-CIFAR-10':
        diffmask_model, extractor = diffmask, feature_extractor
    elif model_name == 'DiffMask-ImageNet':
        diffmask_model, extractor = diffmask_imagenet, feature_extractor_imagenet
    else:
        # Unknown/missing selection: previously an UnboundLocalError.
        return None, None, None

    # Class-index -> class-name mapping of the underlying classifier.
    id2label = diffmask_model.model.config.id2label

    # Prepare image: (H, W, C) array -> (C, H, W) float tensor scaled by 1/255,
    # then add a batch dimension after feature extraction.
    image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
    dm_image = extractor(image).unsqueeze(0)
    dm_out = diffmask_model.get_mask(dm_image)

    # Get mask and apply it on the image.
    mask = dm_out["mask"][0].detach()
    masked_img = draw_mask(image, mask)
    heatmap = draw_heatmap(image, mask)

    # Map logits of the original and the masked input to per-class probabilities.
    n_classes = len(id2label)
    probs_orig = dm_out["logits_orig"][0].detach().softmax(dim=-1)
    probs_mask = dm_out["logits"][0].detach().softmax(dim=-1)
    orig_probs = {id2label[i]: probs_orig[i].item() for i in range(n_classes)}
    mask_probs = {id2label[i]: probs_mask[i].item() for i in range(n_classes)}
    return np.hstack((masked_img, heatmap)), orig_probs, mask_probs
# Launch demo interface
gr.Interface(
    get_mask,
    inputs=[
        gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
        # Pre-select a model: with live=True the callback fires as soon as an
        # image is uploaded, and an unset dropdown would pass model_name=None.
        gr.inputs.Dropdown(
            label="Model Name",
            choices=["DiffMask-ImageNet", "DiffMask-CIFAR-10"],
            default="DiffMask-ImageNet",
        ),
    ],
    outputs=[
        gr.outputs.Image(label="Output"),
        gr.outputs.Label(label="Original Prediction", num_top_classes=5),
        gr.outputs.Label(label="Masked Prediction", num_top_classes=5),
    ],
    examples=[
        ["dogcat.jpeg", "DiffMask-ImageNet"],
        ["elephant-zebra.jpg", "DiffMask-ImageNet"],
        ["finch.jpeg", "DiffMask-ImageNet"],
    ],
    title="Vision DiffMask Demo",
    live=True,
).launch()