Spaces:
Runtime error
Runtime error
File size: 6,871 Bytes
1d12695 9787127 1d12695 9787127 1d12695 9787127 1d12695 9787127 1d12695 9787127 1d12695 9787127 1d12695 9787127 1d12695 9787127 1d12695 9787127 1d12695 9787127 1d12695 9787127 1d12695 9787127 1d12695 9787127 1d12695 9787127 8f0fbfc 721bced 1d12695 9787127 1d12695 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import gradio as gd
import numpy as np
import os
import torch
import torchvision
import torchvision.models as models
from lime import lime_image
import matplotlib.pyplot as plt
import matplotlib
import torch.nn.functional as F
from skimage.segmentation import mark_boundaries
from PIL import Image
matplotlib.use('agg')
def run_lime(input_image,
             model_name: str,
             top_labels: int,
             num_samples: int,
             num_features: int,
             batch_size: int):
    """Explain ``input_image`` with LIME using the selected torchvision model.

    Args:
        input_image: numpy array of shape (height, width, channels) with values
            in [0, 255], as delivered by the Gradio image input.
        model_name: torchvision model name, passed to ``fetch_model``.
        top_labels: number of most likely classes LIME builds explanations
            for. Values below 1 are raised to 1 because the top label is
            always visualised.
        num_samples: number of perturbed images LIME generates.
        num_features: number of superpixels (features) shown in the
            explanation image.
        batch_size: how many perturbed images LIME hands to the classifier
            at once.

    Returns:
        ``(lime_output, top_10_classes)``:
        - ``lime_output`` is the float image returned by ``mark_boundaries``.
        - ``top_10_classes`` is a list of ``[class_id, label, probability]``
          rows for the (up to) 10 most likely classes.

    Raises:
        ValueError: if no image was supplied, or no ImageNet-1K weights exist
            for ``model_name``.
    """
    if input_image is None:
        raise ValueError('Please upload an image first')
    # Gradio Number components may deliver floats; LIME expects ints.
    top_labels = max(1, int(top_labels))
    num_samples = int(num_samples)
    num_features = int(num_features)
    batch_size = int(batch_size)
    print('model_name', model_name)
    print('top_labels', top_labels)
    print('num_samples', num_samples)
    print('num_features', num_features)
    print('batch_size', batch_size)
    print('input image', type(input_image), input_image.shape)

    model, weights = fetch_model(model_name)
    if model is None:
        raise ValueError(f'No ImageNet-1K weights found for model {model_name!r}')
    # Pretrained torchvision models must run in eval mode for inference;
    # otherwise dropout / stochastic depth are random and BatchNorm
    # running statistics are overwritten by the LIME samples.
    model.eval()
    preprocess = weights.transforms(antialias=True)

    def to_batch(images):
        # Sequence of HWC numpy images -> preprocessed NCHW float tensor.
        return torch.stack(
            [preprocess(torch.from_numpy(img.transpose(2, 0, 1))) for img in images])

    with torch.no_grad():
        probs = F.softmax(model(to_batch([input_image])), dim=1)
    print('probs', type(probs), probs.shape)

    names = weights.meta['categories']
    top_probs, top_ids = probs[0].topk(min(10, probs.shape[1]))
    top_10_classes = [[class_id, names[class_id], prob]
                      for class_id, prob in zip(top_ids.tolist(), top_probs.tolist())]
    for row in top_10_classes:
        print(*row)

    def classifier_fn(images):
        # The batch is sized from len(images), not batch_size: the last batch
        # LIME passes can be smaller than batch_size (e.g. when num_samples is
        # not a multiple of it), which previously raised IndexError.
        print('classifier_fn', type(images), images.shape)
        with torch.no_grad():
            batch_probs = F.softmax(model(to_batch(images)), dim=1)
        return batch_probs.cpu().numpy()

    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        input_image,
        classifier_fn,
        top_labels=top_labels,
        hide_color=0,
        num_samples=num_samples,
        num_features=num_features,
        batch_size=batch_size)
    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0],
        positive_only=False, num_features=num_features, hide_rest=False)
    # temp is in the input's [0, 255] range; rescale before drawing boundaries.
    lime_output = mark_boundaries(temp / 255.0, mask)
    return lime_output, top_10_classes
def fetch_model_names():
    """Return the model names torchvision lists for the ``torchvision.models`` module.

    Used to populate the model dropdown in the UI.
    """
    registry_module = torchvision.models
    available = models.list_models(module=registry_module)
    return available
def fetch_model(model_name):
    """Build torchvision model ``model_name`` with ImageNet-1K weights.

    The first member of the model's weights enum whose name contains
    "IMAGENET1K" is used. The returned model is switched to eval mode so it
    is ready for inference.

    Returns:
        ``(model, weights)``, or ``(None, None)`` when the model has no
        ImageNet-1K weights.
    """
    print('Retrieving model ', model_name)
    weights = next((w for w in models.get_model_weights(model_name)
                    if "IMAGENET1K" in w.name), None)
    if weights is None:
        return None, None
    model = models.get_model(model_name, weights=weights)
    # nn.Module instances start in training mode; pretrained models must be
    # evaluated in eval mode (dropout, stochastic depth, BatchNorm stats).
    model.eval()
    print('Model weights loaded', weights.name)
    return model, weights
# Gradio UI: image + LIME parameters in, explanation image + top-10 table out.
with gd.Blocks() as demo:
    with gd.Column():
        gd.Markdown(value='''
# A simple GUI for LIME
This is a simple GUI for Local Interpretable Model-agnostic Explanations (LIME). It allows you to run LIME on a variety of models and images. I've used the following resources to build this GUI:
* [LIME](https://github.com/marcotcr/lime)
* [LIME tutorial](https://github.com/marcotcr/lime/blob/master/tutorials/lime_image.ipynb)
''')
        with gd.Row():
            input_image = gd.Image(label="Input Image. Please upload an image that you want LIME to explain")
            with gd.Column():
                model_name = gd.Dropdown(label="Model",
                                         info='''
Select the image classification model to use for LIME.
The list is automatically populated by using torchvision library.
''',
                                         value='convnext_tiny',
                                         choices=fetch_model_names())
                top_labels = gd.Number(label='top_labels', info='''
use the first <top_labels> labels to create explanations.
For example, setting top_labels=5 will create explanations
for the top 5 most likely classes.''',
                                       precision=0, value=5)
                num_samples = gd.Number(label="num_samples",
                                        info="How many samples to be created to build the linear model inside LIME",
                                        precision=0, value=100)
            with gd.Column():
                num_features = gd.Number(label="num_features",
                                         info='Among the most important superpixels (features), how many to be shown in the explanation image',
                                         precision=0, value=2)
                batch_size = gd.Number(label="batch_size",
                                       info='how many images in the samples to be processed at once',
                                       precision=0, value=20)
                # A Button's caption is its `value`; it has no `label` parameter.
                run_button = gd.Button("Run")
        with gd.Row():
            # DataFrame and Image take no `info` argument (newer Gradio raises
            # TypeError), so their descriptions are rendered as Markdown.
            with gd.Column():
                gd.Markdown("Top-10 classes for the input image calculated by using the selected model")
                top_10_classes = gd.DataFrame(label="Top 10 classes",
                                              headers=["class_id", "label", "probability"],
                                              datatype=["number", "str", "number"])
            with gd.Column():
                gd.Markdown("The explanation image for the input image calculated by LIME for the selected model")
                lime_output = gd.Image(label="Lime Explanation")
        gd.Examples(
            label="Some examples images and parameters",
            examples=[["jeep.png", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0154.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0155.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0156.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0157.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0158.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0159.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0160.jpg", "convnext_tiny", 5, 100, 2, 20]],
            inputs=[input_image, model_name, top_labels, num_samples, num_features, batch_size])
    run_button.click(fn=run_lime,
                     inputs=[input_image, model_name, top_labels, num_samples, num_features, batch_size],
                     outputs=[lime_output, top_10_classes])

if __name__ == "__main__":
    demo.launch()
|