# Hugging Face Space: a simple Gradio GUI for LIME explanations.
# NOTE(review): the hosted Space currently reports "Runtime error" — see
# the UI-construction fixes at the bottom of this file.
import os

import gradio as gd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from lime import lime_image
from PIL import Image
from skimage.segmentation import mark_boundaries

# Non-interactive backend: figures are rendered to images, never to a display.
matplotlib.use('agg')
def run_lime(input_image,
             model_name: str,
             top_labels: int,
             num_samples: int,
             num_features: int,
             batch_size: int):
    """Run LIME on ``input_image`` with the selected torchvision model.

    Args:
        input_image: HxWxC uint8 numpy array, values in [0, 255] (as
            delivered by the Gradio Image component).
        model_name: torchvision classification model name
            (see ``fetch_model_names``).
        top_labels: how many top classes LIME builds explanations for.
        num_features: how many superpixels to show in the explanation image.
        num_samples: perturbed samples used to fit LIME's linear model.
        batch_size: batch size LIME uses when scoring perturbed samples.

    Returns:
        Tuple of (explanation image with superpixel boundaries drawn,
        list of ``[class_id, label, probability]`` rows for the top-10
        classes of the unperturbed input).

    Raises:
        ValueError: if no IMAGENET1K weights exist for ``model_name``.
    """
    model, weights = fetch_model(model_name)
    if model is None:
        raise ValueError(f'No IMAGENET1K weights available for model {model_name!r}')
    # Inference mode: disable dropout/batch-norm updates.
    model.eval()
    preprocess = weights.transforms(antialias=True)

    # Classify the unperturbed input to report the top-10 classes.
    input_tensor = preprocess(torch.from_numpy(input_image.transpose(2, 0, 1))).unsqueeze(0)
    with torch.no_grad():
        probs = F.softmax(model(input_tensor), dim=1)
    names = weights.meta['categories']
    top_probs, top_idx = probs[0].topk(10)
    top_10_classes = [[i.item(), names[i], p.item()] for p, i in zip(top_probs, top_idx)]

    def classifier_fn(images):
        # LIME hands us a numpy batch of HxWxC images. The final batch may
        # contain FEWER than `batch_size` images (when num_samples is not a
        # multiple of batch_size), so size the tensor by len(images) — the
        # original code iterated range(batch_size) and indexed past the end
        # of short batches.
        batch = torch.stack(
            [preprocess(torch.from_numpy(img.transpose(2, 0, 1))) for img in images])
        with torch.no_grad():
            out = F.softmax(model(batch), dim=1)
        return out.cpu().numpy()

    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        input_image,
        classifier_fn,
        top_labels=top_labels,
        hide_color=0,
        num_samples=num_samples,
        num_features=num_features,
        batch_size=batch_size)
    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0],
        positive_only=False, num_features=num_features, hide_rest=False)
    # mark_boundaries expects floats in [0, 1]; `temp` comes back in [0, 255].
    lime_output = mark_boundaries(temp / 255.0, mask)
    return lime_output, top_10_classes
def fetch_model_names():
    """Return the names of all classification models exposed by torchvision."""
    available = models.list_models(module=torchvision.models)
    return available
def fetch_model(model_name):
    """Load ``model_name`` with its first available ImageNet-1k weight set.

    Returns:
        ``(model, weights)`` on success, or ``(None, None)`` when the model
        has no weight set whose name contains "IMAGENET1K".
    """
    print('Retrieving model ', model_name)
    for candidate in models.get_model_weights(model_name):
        if "IMAGENET1K" not in candidate.name:
            continue
        loaded = models.get_model(model_name, weights=candidate)
        print('Model weights loaded', candidate.name)
        return loaded, candidate
    return None, None
# Build the Gradio UI: input image and LIME parameters on top, the top-10
# class table and the explanation image below.
with gd.Blocks() as demo:
    with gd.Column():
        gd.Markdown(value='''
# A simple GUI for LIME
This is a simple GUI for Local Interpretable Model-agnostic Explanations (LIME). It allows you to run LIME on a variety of models and images. I've used the following resources to build this GUI:
* [LIME](https://github.com/marcotcr/lime)
* [LIME tutorial](https://github.com/marcotcr/lime/blob/master/tutorials/lime_image.ipynb)
''')
        with gd.Row():
            input_image = gd.Image(label="Input Image. Please upload an image that you want LIME to explain")
            with gd.Column():
                model_name = gd.Dropdown(
                    label="Model",
                    info='''
                    Select the image classification model to use for LIME.
                    The list is automatically populated by using torchvision library.
                    ''',
                    value='convnext_tiny',
                    choices=fetch_model_names())
                top_labels = gd.Number(
                    label='top_labels',
                    info='''
                    use the first <top_labels> labels to create explanations.
                    For example, setting top_labels=5 will create explanations
                    for the top 5 most likely classes.''',
                    precision=0, value=5)
                num_samples = gd.Number(
                    label="num_samples",
                    info="How many samples to be created to build the linear model inside LIME",
                    precision=0, value=100)
            with gd.Column():
                num_features = gd.Number(
                    label="num_features",
                    info='Among the most important superpixels (features), how many to be shown in the explanation image',
                    precision=0, value=2)
                batch_size = gd.Number(
                    label="batch_size",
                    info='how many images in the samples to be processed at once',
                    precision=0, value=20)
                # FIX: gd.Button has no `label` parameter — the button text
                # is its positional `value`; `label=` raised at startup.
                run_button = gd.Button("Run")
        with gd.Row():
            # FIX: DataFrame and (output) Image do not accept `info=`; the
            # explanatory texts are kept here as comments instead.
            # Top-10 classes for the input image, computed by the selected model.
            top_10_classes = gd.DataFrame(
                label="Top 10 classes",
                headers=["class_id", "label", "probability"],
                datatype=["number", "str", "number"])
            # The explanation image calculated by LIME for the selected model.
            lime_output = gd.Image(label="Lime Explanation")
        gd.Examples(
            label="Some examples images and parameters",
            examples=[["jeep.png", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0154.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0155.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0156.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0157.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0158.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0159.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0160.jpg", "convnext_tiny", 5, 100, 2, 20]],
            inputs=[input_image, model_name, top_labels, num_samples, num_features, batch_size])
    run_button.click(
        fn=run_lime,
        inputs=[input_image, model_name, top_labels, num_samples, num_features, batch_size],
        outputs=[lime_output, top_10_classes])

if __name__ == "__main__":
    demo.launch()