File size: 6,871 Bytes
1d12695
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9787127
1d12695
9787127
1d12695
9787127
1d12695
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9787127
1d12695
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9787127
 
 
 
 
 
1d12695
9787127
1d12695
 
9787127
 
 
 
 
1d12695
9787127
 
 
 
1d12695
 
9787127
 
 
1d12695
9787127
 
1d12695
9787127
 
1d12695
 
9787127
 
1d12695
 
9787127
 
1d12695
9787127
 
8f0fbfc
 
721bced
 
 
 
 
1d12695
 
 
9787127
1d12695
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import functools
import os

import gradio as gd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from lime import lime_image
from PIL import Image
from skimage.segmentation import mark_boundaries
# Select matplotlib's non-interactive 'agg' backend, so rendering never needs a
# GUI/display (presumably because the app runs on a headless server — confirm).
matplotlib.use('agg')

def run_lime(input_image,
             model_name: str,
             top_labels: int,
             num_samples: int,
             num_features: int,
             batch_size: int):
    """Explain the selected torchvision model's prediction on one image with LIME.

    Args:
        input_image: Image from the Gradio ``Image`` component. It is assumed to
            be a ``(height, width, channels)`` numpy array with values in
            ``[0, 255]`` (TODO confirm against the component's settings).
        model_name: Name of a torchvision classification model.
        top_labels: Number of most likely labels LIME builds explanations for.
        num_samples: Number of perturbed samples LIME generates.
        num_features: Number of superpixels shown in the explanation image.
        batch_size: Maximum number of perturbed images classified per forward
            pass.

    Returns:
        A tuple ``(lime_output, top_10_classes)``:
        - ``lime_output`` is skimage's ``mark_boundaries`` rendering of the
          LIME image/mask (scaled by 1/255) for LIME's top label.
        - ``top_10_classes`` is a list of ``[class_id, label, probability]``
          rows for the (up to) ten most likely classes, most likely first.

    Raises:
        gradio.Error: if no image was provided, or if ``fetch_model`` found no
            ImageNet-1K weights for ``model_name``.
    """
    if input_image is None:
        raise gd.Error("Please upload an input image before running LIME.")

    print('model_name', model_name)
    print('top_labels', top_labels)
    print('num_samples', num_samples)
    print('num_features', num_features)
    print('batch_size', batch_size)
    print('input image', type(input_image), input_image.shape)

    model, weights = fetch_model(model_name)
    if model is None:
        # fetch_model signals "no ImageNet-1K weights" by returning (None, None).
        raise gd.Error(f"No ImageNet-1K weights are available for model '{model_name}'.")
    # Inference only. Without eval(), dropout, stochastic depth and batch-norm
    # run in training mode, so predictions would be wrong or nondeterministic.
    model.eval()
    preprocess = weights.transforms(antialias=True)
    names = weights.meta['categories']

    def to_model_input(image):
        # HWC numpy image -> CHW tensor, then the weights' own preprocessing.
        return preprocess(torch.from_numpy(image.transpose(2, 0, 1)))

    with torch.no_grad():
        probs = F.softmax(model(to_model_input(input_image).unsqueeze(0)), dim=1)
    print('probs', type(probs), probs.shape)

    top_probs, top_ids = probs[0].topk(min(10, probs.shape[1]))
    top_10_classes = []
    for class_id, prob in zip(top_ids.tolist(), top_probs.tolist()):
        print(class_id, names[class_id], prob)
        top_10_classes.append([class_id, names[class_id], prob])

    def classifier_fn(images):
        """Return softmax class probabilities for a stack of perturbed images.

        LIME may pass fewer than ``batch_size`` images. For example, the last
        batch is shorter when ``num_samples`` is not a multiple of
        ``batch_size``. So the tensor batch is built from ``images`` itself
        rather than being sized and indexed by ``batch_size``.
        """
        print('classifier_fn', type(images), images.shape)
        print('len(images)', len(images))
        batch = torch.stack([to_model_input(image) for image in images])
        print('batch', type(batch), batch.shape)
        with torch.no_grad():
            batch_probs = F.softmax(model(batch), dim=1)
        print('probs', type(batch_probs), batch_probs.shape)
        return batch_probs.cpu().numpy()

    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        input_image,
        classifier_fn,
        top_labels=top_labels,
        hide_color=0,
        num_samples=num_samples,
        num_features=num_features,
        batch_size=batch_size)

    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0],
        positive_only=False, num_features=num_features, hide_rest=False)
    # Scale the image to [0, 1] before drawing the superpixel boundaries.
    lime_output = mark_boundaries(temp / 255.0, mask)
    return lime_output, top_10_classes

def fetch_model_names():
    """Return the model names torchvision lists for the ``torchvision.models`` module.

    These names populate the model dropdown in the UI.
    """
    source_module = torchvision.models
    available_names = models.list_models(module=source_module)
    return available_names

@functools.lru_cache(maxsize=2)
def fetch_model(model_name):
    """Build a torchvision model with its first ImageNet-1K weights, in eval mode.

    The weights used are the first entry of
    ``torchvision.models.get_model_weights(model_name)`` whose name contains
    ``"IMAGENET1K"``. Results are memoised for the two most recently used model
    names, so repeated runs do not rebuild the model and reload its weights.
    Callers therefore share one model object per name.

    Args:
        model_name: Name of a torchvision model.

    Returns:
        ``(model, weights)`` with ``model`` switched to eval mode, or
        ``(None, None)`` when no weights entry matching "IMAGENET1K" exists.
    """
    print('Retrieving model ', model_name)
    weights = next(
        (w for w in models.get_model_weights(model_name) if "IMAGENET1K" in w.name),
        None,
    )
    if weights is None:
        return None, None
    model = models.get_model(model_name, weights=weights)
    # The model is only used for inference. Disable training-mode behaviour
    # (dropout, stochastic depth, batch-norm batch statistics).
    model.eval()
    print('Model weights loaded', weights.name)
    return model, weights

# Gradio UI: the input image and LIME parameters on top, and the model's top-10
# predictions plus the LIME explanation image below. "Run" calls run_lime.
with gd.Blocks() as demo:
    with gd.Column():
        gd.Markdown(value='''
        # A simple GUI for LIME
        This is a simple GUI for Local Interpretable Model-agnostic Explanations (LIME). It allows you to run LIME on a variety of models and images. I've used the following resources to build this GUI:
        * [LIME](https://github.com/marcotcr/lime)
        * [LIME tutorial](https://github.com/marcotcr/lime/blob/master/tutorials/lime_image.ipynb)
        ''')
        with gd.Row():
            input_image = gd.Image(label="Input Image. Please upload an image that you want LIME to explain")
            with gd.Column():
                model_name = gd.Dropdown(label="Model",
                                         info='''
                                         Select the image classification model to use for LIME.
                                         The list is automatically populated by using torchvision library.
                                         ''',
                                         value='convnext_tiny',
                                         choices=fetch_model_names())
                top_labels = gd.Number(label='top_labels',info='''
                        use the first <top_labels> labels to create explanations.
                        For example, setting top_labels=5 will create explanations 
                        for the top 5 most likely classes.''',
                                       precision=0, value=5)
                num_samples = gd.Number(label="num_samples",
                                        info="How many samples to be created to build the linear model inside LIME",
                                        precision=0, value=100)
            with gd.Column():
                num_features = gd.Number(label="num_features",
                                         info='Among the most important superpixels (features), how many to be shown in the explanation image',
                                         precision=0, value=2)
                batch_size = gd.Number(label="batch_size",
                                       info='how many images in the samples to be processed at once',
                                       precision=0, value=20)
                # A Button's caption is its `value`; `label` is not a caption,
                # and newer Gradio releases reject it on Button.
                run_button = gd.Button(value="Run")
        with gd.Row():
            # NOTE(review): `info` is passed to DataFrame/Image here as it is to
            # the form inputs above. Confirm that the installed Gradio version
            # accepts (and displays) it for these components.
            top_10_classes = gd.DataFrame(label="Top 10 classes",
                                          info="Top-10 classes for the input image calculated by using the selected model",
                                          headers=["class_id", "label", "probability"],
                                          datatype=["number", "str", "number"])
            lime_output = gd.Image(label="Lime Explanation",
                                   info="The explanation image for the input image calculated by LIME for the selected model")
        # Each example row fills, in order:
        # [input_image, model_name, top_labels, num_samples, num_features, batch_size].
        gd.Examples(
            label="Some examples images and parameters",
            examples=[["jeep.png", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0154.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0155.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0156.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0157.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0158.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0159.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0160.jpg", "convnext_tiny", 5, 100, 2, 20]],
            inputs=[input_image, model_name, top_labels, num_samples, num_features, batch_size])

    # run_lime returns (explanation image, top-10 rows), matching `outputs` order.
    run_button.click(fn=run_lime,
                     inputs=[input_image, model_name, top_labels, num_samples, num_features, batch_size],
                     outputs=[lime_output, top_10_classes])

if __name__ == "__main__":
    demo.launch()