import gradio as gd
import numpy as np
import os
import torch
import torchvision
import torchvision.models as models
from lime import lime_image
import matplotlib.pyplot as plt
import matplotlib
import torch.nn.functional as F
from skimage.segmentation import mark_boundaries
from PIL import Image

# Non-interactive backend: nothing in this app opens a plotting window.
matplotlib.use('agg')


def run_lime(input_image, model_name: str, top_labels: int, num_samples: int,
             num_features: int, batch_size: int):
    """Classify ``input_image`` with a torchvision model and explain it with LIME.

    Args:
        input_image: numpy array of shape (height, width, channels) with values
            in [0, 255], as delivered by the ``gd.Image`` input.
        model_name: torchvision model name (see ``fetch_model_names``).
        top_labels: how many of the most likely classes LIME explains.
        num_samples: number of perturbed images LIME generates.
        num_features: number of superpixels kept in the explanation image.
        batch_size: how many perturbed images LIME sends per classifier call.

    Returns:
        ``(lime_output, top_10_classes)``: the input image scaled to [0, 1] with
        the explanation's superpixel boundaries marked (for the most likely
        label), and up to ten ``[class_id, label, probability]`` rows sorted by
        descending probability.

    Raises:
        gd.Error: if no image was supplied, or the model has no weights whose
            name contains "IMAGENET1K".
    """
    print('model_name', model_name)
    print('top_labels', top_labels)
    print('num_samples', num_samples)
    print('num_features', num_features)
    print('batch_size', batch_size)
    if input_image is None:
        raise gd.Error('Please upload an image first.')
    print('input image', type(input_image), input_image.shape)

    model, weights = fetch_model(model_name)
    if model is None:
        raise gd.Error(f'No ImageNet-1K weights are available for {model_name}.')
    # Inference only: eval() turns off dropout / stochastic depth and makes
    # batch-norm use running statistics, so every batch is scored the same way.
    model.eval()
    preprocess = weights.transforms(antialias=True)
    names = weights.meta['categories']

    def to_model_input(image):
        # (H, W, C) numpy image -> (C, H, W) tensor preprocessed for the model.
        return preprocess(torch.from_numpy(image.transpose(2, 0, 1)))

    with torch.inference_mode():
        probs = F.softmax(model(to_model_input(input_image).unsqueeze(0)), dim=1)
    print('probs', type(probs), probs.shape)

    top_probs, top_ids = probs[0].topk(min(10, probs.shape[1]))
    top_10_classes = []
    for class_id, prob in zip(top_ids.tolist(), top_probs.tolist()):
        print(class_id, names[class_id], prob)
        top_10_classes.append([class_id, names[class_id], prob])

    def classifier_fn(images):
        # LIME passes a stack of perturbed (H, W, C) images. Chunks hold
        # batch_size images except possibly the last one (when num_samples is
        # not a multiple of batch_size), so size the batch from the input.
        print('classifier_fn', type(images), images.shape)
        print('len(images)', len(images))
        batch = torch.stack([to_model_input(image) for image in images])
        print('batch', type(batch), batch.shape)
        with torch.inference_mode():
            batch_probs = F.softmax(model(batch), dim=1)
        print('probs', type(batch_probs), batch_probs.shape)
        return batch_probs.cpu().numpy()

    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        input_image,
        classifier_fn,
        top_labels=top_labels,
        hide_color=0,
        num_samples=num_samples,
        num_features=num_features,
        batch_size=batch_size)
    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0],
        positive_only=False,
        num_features=num_features,
        hide_rest=False)
    # temp carries the input's pixel values ([0, 255]); mark_boundaries wants floats in [0, 1].
    lime_output = mark_boundaries(temp / 255.0, mask)
    return lime_output, top_10_classes


def fetch_model_names():
    """Return the model names registered directly in ``torchvision.models``."""
    return models.list_models(module=torchvision.models)


def fetch_model(model_name):
    """Load ``model_name`` with its first weights whose name contains "IMAGENET1K".

    Returns:
        ``(model, weights)`` with ``weights`` the chosen weights enum member, or
        ``(None, None)`` when no such weights exist for the model.
    """
    print('Retrieving model ', model_name)
    weights = next(
        (w for w in models.get_model_weights(model_name) if "IMAGENET1K" in w.name),
        None)
    if weights is None:
        return None, None
    model = models.get_model(model_name, weights=weights)
    print('Model weights loaded', weights.name)
    return model, weights


DESCRIPTION = '''
# A simple GUI for LIME

This is a simple GUI for Local Interpretable Model-agnostic Explanations (LIME).
It allows you to run LIME on a variety of models and images.

I've used the following resources to build this GUI:

* [LIME](https://github.com/marcotcr/lime)
* [LIME tutorial](https://github.com/marcotcr/lime/blob/master/tutorials/lime_image.ipynb)
'''

with gd.Blocks() as demo:
    with gd.Column():
        gd.Markdown(value=DESCRIPTION)
        with gd.Row():
            input_image = gd.Image(
                label="Input Image. Please upload an image that you want LIME to explain")
            with gd.Column():
                model_name = gd.Dropdown(
                    label="Model",
                    info=''' Select the image classification model to use for LIME. The list is automatically populated by using torchvision library. ''',
                    value='convnext_tiny',
                    choices=fetch_model_names())
                top_labels = gd.Number(
                    label='top_labels',
                    info='''Number of most likely labels to create explanations for.
For example, setting top_labels=5 will create explanations for the top 5 most likely classes.''',
                    precision=0,
                    value=5)
                num_samples = gd.Number(
                    label="num_samples",
                    info="How many samples to be created to build the linear model inside LIME",
                    precision=0,
                    value=100)
            with gd.Column():
                num_features = gd.Number(
                    label="num_features",
                    info='Among the most important superpixels (features), how many to be shown in the explanation image',
                    precision=0,
                    value=2)
                batch_size = gd.Number(
                    label="batch_size",
                    info='how many images in the samples to be processed at once',
                    precision=0,
                    value=20)
        run_button = gd.Button(value="Run")
        with gd.Row():
            top_10_classes = gd.DataFrame(
                label="Top 10 classes",
                info="Top-10 classes for the input image calculated by using the selected model",
                headers=["class_id", "label", "probability"],
                datatype=["number", "str", "number"])
            lime_output = gd.Image(
                label="Lime Explanation",
                info="The explanation image for the input image calculated by LIME for the selected model")
        gd.Examples(
            label="Some examples images and parameters",
            examples=[["jeep.png", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0154.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0155.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0156.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0157.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0158.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0159.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0160.jpg", "convnext_tiny", 5, 100, 2, 20]],
            inputs=[input_image, model_name, top_labels, num_samples, num_features, batch_size])
    run_button.click(
        fn=run_lime,
        inputs=[input_image, model_name, top_labels, num_samples, num_features, batch_size],
        outputs=[lime_output, top_10_classes])

if __name__ == "__main__":
    demo.launch()