# LIME / app.py
# Author: hkayabilisim (commit 721bced, "Upload 10 files")
import gradio as gd
import numpy as np
import os
import torch
import torchvision
import torchvision.models as models
from lime import lime_image
import matplotlib.pyplot as plt
import matplotlib
import torch.nn.functional as F
from skimage.segmentation import mark_boundaries
from PIL import Image
# Switch matplotlib to the non-interactive Agg backend so rendering never
# requires a display.
matplotlib.use(backend='agg')
def run_lime(input_image,
             model_name: str,
             top_labels: int,
             num_samples: int,
             num_features: int,
             batch_size: int):
    """Classify ``input_image`` with a torchvision model and explain it with LIME.

    Args:
        input_image: numpy array of shape (height, width, channels) with
            values in [0, 255], as delivered by the Gradio image input.
        model_name: torchvision model name, resolved through ``fetch_model``.
        top_labels: number of most likely labels LIME builds explanations for.
        num_samples: number of perturbed samples LIME generates.
        num_features: number of superpixels shown in the explanation image.
        batch_size: maximum number of perturbed images classified per
            forward pass.

    Returns:
        ``(lime_output, top_10_classes)``. ``lime_output`` is the image with
        the explanation mask boundaries drawn by ``mark_boundaries``.
        ``top_10_classes`` is a list of ``[class_id, label, probability]``
        rows for the ten most probable classes.

    Raises:
        gd.Error: if no image was supplied or the model has no ImageNet-1K
            weights.
    """
    if input_image is None:
        raise gd.Error('Please upload an input image first.')
    # Gradio numeric inputs may arrive as floats; torch and LIME need ints.
    top_labels = int(top_labels)
    num_samples = int(num_samples)
    num_features = int(num_features)
    batch_size = int(batch_size)

    print('model_name', model_name)
    print('top_labels', top_labels)
    print('num_samples', num_samples)
    print('num_features', num_features)
    print('batch_size', batch_size)
    print('input image', type(input_image), input_image.shape)

    model, weights = fetch_model(model_name)
    if model is None:
        raise gd.Error(f'No ImageNet-1K weights are available for model {model_name!r}.')
    # Inference only: eval mode disables training-time randomness such as
    # dropout, so repeated forward passes on the same input agree.
    model.eval()
    preprocess = weights.transforms(antialias=True)

    def to_batch(images):
        """Preprocess a sequence of HWC images into one NCHW model input."""
        return torch.stack([preprocess(torch.from_numpy(img.transpose(2, 0, 1)))
                            for img in images])

    with torch.no_grad():
        probs = F.softmax(model(to_batch([input_image])), dim=1)
    names = weights.meta['categories']
    print('probs', type(probs), probs.shape)
    top_10_classes = []
    for x in probs.argsort(descending=True)[0][:10]:
        class_id = x.item()
        probability = probs[0, class_id].item()
        print(class_id, names[class_id], probability)
        top_10_classes.append([class_id, names[class_id], probability])

    def classifier_fn(images):
        """Return class probabilities (numpy) for a chunk of perturbed images."""
        # lime_image sends chunks of up to batch_size images, and the last
        # chunk is smaller when num_samples is not a multiple of batch_size.
        # Size the batch from the input rather than from batch_size.
        print('classifier_fn', type(images), images.shape)
        print('len(images)', len(images))
        batch = to_batch(images)
        print('batch', type(batch), batch.shape)
        with torch.no_grad():
            batch_probs = F.softmax(model(batch), dim=1)
        print('probs', type(batch_probs), batch_probs.shape)
        return batch_probs.cpu().numpy()

    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        input_image,
        classifier_fn,
        top_labels=top_labels,
        hide_color=0,
        num_samples=num_samples,
        num_features=num_features,
        batch_size=batch_size)
    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0],
        positive_only=False, num_features=num_features, hide_rest=False)
    # temp is in the input's [0, 255] range; mark_boundaries expects [0, 1] floats.
    lime_output = mark_boundaries(temp / 255.0, mask)
    return lime_output, top_10_classes
def fetch_model_names():
    """List the model names that torchvision registers directly in ``torchvision.models``."""
    top_level_module = torchvision.models
    available = models.list_models(module=top_level_module)
    return available
def fetch_model(model_name):
    """Load torchvision model ``model_name`` with ImageNet-1K weights.

    The first member of the model's weights enum whose name contains
    ``"IMAGENET1K"`` is used.

    Returns:
        ``(model, weights)`` with the model switched to evaluation mode, or
        ``(None, None)`` when no ImageNet-1K weights exist for the model.
    """
    print('Retrieving model ', model_name)
    weights_enum = models.get_model_weights(model_name)
    weights = next((w for w in weights_enum if "IMAGENET1K" in w.name), None)
    if weights is None:
        return None, None
    model = models.get_model(model_name, weights=weights)
    # nn.Module instances start in training mode; inference needs eval mode
    # so that dropout-style layers do not randomize predictions.
    model.eval()
    print('Model weights loaded', weights.name)
    return model, weights
# Build the Gradio UI: LIME parameters on top, results below, example inputs
# at the bottom. The Run button wires all inputs to run_lime.
with gd.Blocks() as demo:
    with gd.Column():
        gd.Markdown(value='''
# A simple GUI for LIME
This is a simple GUI for Local Interpretable Model-agnostic Explanations (LIME). It allows you to run LIME on a variety of models and images. I've used the following resources to build this GUI:
* [LIME](https://github.com/marcotcr/lime)
* [LIME tutorial](https://github.com/marcotcr/lime/blob/master/tutorials/lime_image.ipynb)
''')
        with gd.Row():
            input_image = gd.Image(label="Input Image. Please upload an image that you want LIME to explain")
            with gd.Column():
                model_name = gd.Dropdown(label="Model",
                    info='''
Select the image classification model to use for LIME.
The list is automatically populated by using torchvision library.
''',
                    value='convnext_tiny',
                    choices=fetch_model_names())
                top_labels = gd.Number(label='top_labels', info='''
use the first <top_labels> labels to create explanations.
For example, setting top_labels=5 will create explanations
for the top 5 most likely classes.''',
                    precision=0, value=5)
                num_samples = gd.Number(label="num_samples",
                    info="How many samples to be created to build the linear model inside LIME",
                    precision=0, value=100)
            with gd.Column():
                num_features = gd.Number(label="num_features",
                    info='Among the most important superpixels (features), how many to be shown in the explanation image',
                    precision=0, value=2)
                batch_size = gd.Number(label="batch_size",
                    info='how many images in the samples to be processed at once',
                    precision=0, value=20)
                # A Button's visible text is its `value`; it has no `label` parameter.
                run_button = gd.Button(value="Run")
        with gd.Row():
            # Image and Dataframe take no `info` argument. Their descriptions:
            #   top_10_classes: Top-10 classes for the input image calculated
            #       by using the selected model.
            #   lime_output: The explanation image for the input image
            #       calculated by LIME for the selected model.
            top_10_classes = gd.DataFrame(label="Top 10 classes",
                headers=["class_id", "label", "probability"],
                datatype=["number", "str", "number"])
            lime_output = gd.Image(label="Lime Explanation")
        gd.Examples(
            label="Some example images and parameters",
            examples=[["jeep.png", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0154.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0155.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0156.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0157.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0158.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0159.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0160.jpg", "convnext_tiny", 5, 100, 2, 20]],
            inputs=[input_image, model_name, top_labels, num_samples, num_features, batch_size])
    # Event listeners must be attached inside the Blocks context.
    run_button.click(fn=run_lime,
                     inputs=[input_image, model_name, top_labels, num_samples, num_features, batch_size],
                     outputs=[lime_output, top_10_classes])
def _main():
    """Launch the Gradio demo defined above."""
    demo.launch()


if __name__ == "__main__":
    _main()