# xAI2023 / app.py
# hkayabilisim: "Fixed normalization issue for EUROSAT_CUSTOM_MODEL" (commit 45c81bf)
import gradio as gd
import numpy as np
import os
import torch
import torchvision
import torchvision.models as models
from lime import lime_image
import matplotlib.pyplot as plt
import matplotlib
import torch.nn.functional as F
from skimage.segmentation import mark_boundaries
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import wget
import cv2
matplotlib.use('agg')
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.transforms import v2
# Vanilla Legendre between [0,1]
def Pn(m, x):
    """Evaluate the degree-m Legendre polynomial P_m at x (vanilla, on [-1, 1]).

    Uses Bonnet's recurrence iteratively. The original recursive form called
    Pn(m-1) and Pn(m-2) separately, re-evaluating lower degrees exponentially
    often; this version is O(m).

    Args:
        m: non-negative polynomial degree.
        x: scalar or numpy array of evaluation points.

    Returns:
        P_m(x), with the same shape as x (ones for m == 0, x itself for m == 1,
        matching the original's return types).
    """
    if m == 0:
        return np.ones_like(x)
    if m == 1:
        return x
    p_prev, p_curr = np.ones_like(x), x  # P_0, P_1
    for k in range(2, m + 1):
        # Bonnet's recurrence: k*P_k = (2k-1)*x*P_{k-1} - (k-1)*P_{k-2}
        p_prev, p_curr = p_curr, ((2 * k - 1) * x * p_curr - (k - 1) * p_prev) / k
    return p_curr
# Legendre between [a,b]
def L(a, b, m, x):
    """Orthonormal Legendre basis of degree m shifted onto the interval [a, b]."""
    norm = np.sqrt((2 * m + 1) / (b - a))
    t = 2 * (x - b) / (b - a) + 1  # affine map sending [a, b] onto [-1, 1]
    return norm * Pn(m, t)
# Preprocessing for the custom EuroSAT CNN.  Inputs arrive as float tensors in
# the [0, 255] range (see run_lime/run_hdmr), so first rescale to [0, 1], then
# apply the standard ImageNet channel statistics.
eurosat_transform = transforms.Compose([
    transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[255.0, 255.0, 255.0]),  # normalize to [0,1] first
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize images
])
class CNN(nn.Module):
    """Small convolutional classifier: two conv+pool stages, then two FC layers.

    The fc1 input size (64 * 16 * 16) implies 3-channel 64x64 inputs: two
    2x2 max-pools reduce the spatial side 64 -> 32 -> 16.
    """

    def __init__(self, num_classes=10):  # Modify num_classes based on the number of your classes
        super(CNN, self).__init__()
        # NOTE: attribute names and construction order are part of the
        # checkpoint's state_dict layout — do not rename or reorder.
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Return raw class logits for a batch of images."""
        out = self.pool(F.relu(self.conv1(x)))
        out = self.pool(F.relu(self.conv2(out)))
        out = out.view(-1, 64 * 16 * 16)  # flatten conv features per sample
        out = self.dropout(out)
        out = F.relu(self.fc1(out))
        return self.fc2(out)
def run_lime(input_image,
             model_name: str,
             top_labels: int,
             num_samples: int,
             num_features: int,
             batch_size: int):
    """Explain the model's classification of `input_image` with LIME.

    Args:
        input_image: (height, width, channels) numpy array, values in [0, 255].
        model_name: model key understood by fetch_model().
        top_labels: how many of the most likely classes LIME explains.
        num_samples: perturbed samples used to fit LIME's surrogate model.
        num_features: superpixels highlighted in the explanation image.
        batch_size: perturbed images classified per forward pass.

    Returns:
        (explanation image with superpixel boundaries drawn, list of
        [class_id, label, probability-string] rows for the top-10 classes).
    """
    print('model_name', model_name)
    print('top_labels', top_labels)
    print('num_samples', num_samples)
    print('num_features', num_features)
    print('batch_size', batch_size)
    print('input image', type(input_image), input_image.shape)
    model, weights, preprocess, names = fetch_model(model_name)
    input_image_processed = preprocess(
        torch.from_numpy(input_image.astype(np.float32).transpose(2, 0, 1))).unsqueeze(0)
    probs = F.softmax(model(input_image_processed), dim=1)
    top_10_classes = []
    for idx in probs.argsort(descending=True)[0][:10]:
        print(idx.item(), names[idx], probs[0, idx].item())
        top_10_classes.append([idx.item(), names[idx], f'{probs[0, idx].item():.4f}'])

    def classifier_fn(images):
        # LIME may pass fewer than `batch_size` images on the final batch.
        # The original allocated a fixed-size batch and indexed
        # images[0..batch_size-1], raising IndexError in that case; build the
        # batch from the images actually received instead.
        batch = torch.stack([
            preprocess(torch.from_numpy(im.transpose(2, 0, 1).astype(np.float32)))
            for im in images
        ])
        out = F.softmax(model(batch), dim=1)
        return out.detach().cpu().numpy()

    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        input_image,
        classifier_fn,
        top_labels=top_labels,
        hide_color=0,
        num_samples=num_samples,
        num_features=num_features,
        batch_size=batch_size)
    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0],
        positive_only=False, num_features=num_features, hide_rest=False)
    # mark_boundaries expects floats in [0, 1]
    lime_output = mark_boundaries(temp / 255.0, mask)
    return lime_output, top_10_classes
def segmented_image(img, masks, alpha=0.7):
    """Overlay every SAM mask on `img` as a random solid color, blended by `alpha`.

    Args:
        img: RGB image array.
        masks: SAM mask dicts, each with a boolean/0-1 'segmentation' array.
        alpha: blend weight of the colored overlay (1 - alpha goes to `img`).

    Returns:
        The blended overlay image (same shape/dtype as `img`).
    """
    overlay = img.copy()
    for m in masks:
        region = m['segmentation'] == 1
        overlay[region] = 255 * np.random.random(3)  # one random color per segment
    # blend in place: overlay = alpha*overlay + (1-alpha)*img
    cv2.addWeighted(overlay, alpha, img, 1.0 - alpha, 0, overlay)
    return overlay
def segment_heatmap_image(img, masks, mask_weights, num_features_hdmr):
    """Render the top-weighted SAM segments as a JET heatmap over a gray backdrop.

    Args:
        img: RGB image array (rows, cols, channels).
        masks: SAM mask dicts with boolean 'segmentation' arrays.
        mask_weights: one weight per mask, assumed normalized to [0, 1]
            (weights are scaled by 255 into uint8 below — TODO confirm callers
            always pre-normalize, as run_hdmr does).
        num_features_hdmr: how many of the highest-weighted segments to show.

    Returns:
        (heatmap blended over the brightened grayscale image, raw heatmap).
    """
    w, h, c = img.shape
    img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # increase brightness of gray image
    img_gray = cv2.convertScaleAbs(img_gray, alpha=10, beta=0)
    # replicate to 3 channels so it can be blended with the color heatmap
    img_grad3d = np.dstack([img_gray, img_gray, img_gray])
    print(img.shape, img_gray.shape)
    segment_image = np.zeros((w,h)).astype(np.uint8)
    # argsort ascending, so the last num_features_hdmr indices are the largest weights
    important_segment_indices = mask_weights.argsort()[-num_features_hdmr:]
    for i in important_segment_indices:
        mask = masks[i]
        weight = mask_weights[i]
        # paint each selected segment with its weight scaled into [0, 255]
        segment_image[mask['segmentation'] == True] = int(255*weight)
    heatmap_img = cv2.applyColorMap(segment_image, cv2.COLORMAP_JET)
    super_imposed_img = cv2.addWeighted(heatmap_img, 1, img_grad3d, 0.6, 0)
    return super_imposed_img, heatmap_img
def sobol(x, y, m):
    """First-order Sobol sensitivity indices via a Legendre (HDMR) expansion.

    Args:
        x: (N, n) sample matrix; each column is assumed uniform on [0, 1].
        y: length-N array of model outputs for the rows of x.
        m: number of Legendre basis degrees (1..m) to use.

    Returns:
        (n, m) array S where S[k, d] is the first-order index of input k
        accumulated over basis degrees 1..d+1.
    """
    print(x.shape, y.shape)
    N, n = x.shape
    f0 = np.mean(y)
    centered = y - f0
    # alpha[r, i]: projection of the centered output onto the degree-(r+1)
    # orthonormal Legendre basis in coordinate i.
    alpha = np.zeros((m, n))
    for r in range(m):
        for i in range(n):
            alpha[r, i] = np.mean(centered * L(0, 1, r + 1, x[:, i]))
    # total (sample) variance of y
    global_D = np.mean(y ** 2) - f0 ** 2
    # Partial variances accumulate over degrees; a cumulative sum replaces the
    # original O(m^2) re-summation per degree and yields identical values.
    D_first_order = np.cumsum(alpha ** 2, axis=0).T  # shape (n, m)
    S_first_order = D_first_order / global_D
    return S_first_order
def run_hdmr(input_image,
             model_name: str,
             sam_model_name: str,
             num_samples_hdmr: int,
             num_legendre: int,
             num_features_hdmr: int):
    """Explain the classification of `input_image` with HDMR/Sobol indices over SAM segments.

    Segments the image with SAM, randomly attenuates each segment across
    `num_samples_hdmr` perturbed samples, measures how the model's
    unit-normalized logits move (cosine/L1/L2), and converts first-order
    Sobol indices of those distances into per-segment importance heatmaps.

    Args:
        input_image: (height, width, channels) numpy array, values in [0, 255].
        model_name: classifier key understood by fetch_model().
        sam_model_name: "<registry_name> <checkpoint>" string for fetch_sam_model().
        num_samples_hdmr: number of randomly perturbed images to evaluate.
        num_legendre: number of Legendre basis degrees for the HDMR expansion.
        num_features_hdmr: how many top segments to show in each heatmap.

    Returns:
        (cosine heatmap, L1 heatmap, L2 heatmap, three Sobol-index tables
        formatted as strings, SAM segmentation overlay image).
    """
    print('model_name', model_name)
    print('sam_model_name', sam_model_name)
    print('num_samples_hdmr', num_samples_hdmr)
    print('num_features_hdmr', num_features_hdmr)
    print('input image', type(input_image), input_image.shape)
    model, weights, preprocess, names = fetch_model(model_name)
    sam_model = fetch_sam_model(sam_model_name)
    mask_generator = SamAutomaticMaskGenerator(sam_model)
    masks = mask_generator.generate(input_image)
    sam_segmented_image = segmented_image(input_image, masks, alpha=0.9)
    batch = preprocess(torch.from_numpy(input_image.astype(np.float32).transpose(2,0,1))).unsqueeze(0)
    # Unit normalize the logits of the unperturbed image (reference direction)
    logits = model(batch)
    logits = logits[0].detach().numpy()
    logits_length = np.linalg.norm(logits)
    logits_normalized = logits / logits_length
    print('logits_normalized',logits_normalized.shape)
    N = num_samples_hdmr
    n = len(masks)
    # x[sample, i] is the random attenuation factor drawn for segment i
    x = np.random.rand(N, n)
    y = np.zeros((3,N)) # cosine, l1, l2
    # TODO: implement batch_size (samples are currently evaluated one by one)
    for sample in range(N):
        x_input = input_image.copy()
        for i, mask in enumerate(masks):
            x_seg = mask['segmentation']
            # attenuate the pixels of segment i by the squared random factor
            x_input[x_seg == 1] = x_input[x_seg == 1] * np.power(x[sample,i],2)
        batch = preprocess(torch.from_numpy(x_input.astype(np.float32).transpose(2,0,1))).unsqueeze(0)
        # Unit normalize the logits of the perturbed image
        logits_sample = model(batch)
        probs = logits_sample.squeeze(0).softmax(0)
        logits_sample = logits_sample[0].detach().numpy()
        logits_sample_length = np.linalg.norm(logits_sample)
        logits_sample_normalized = logits_sample / logits_sample_length
        # NOTE(review): this dot product is cosine *similarity*, not a distance
        # — confirm whether the sign convention is intended downstream.
        cosine_distance = np.dot(logits_normalized, logits_sample_normalized)
        l1_distance = np.sum(np.abs(logits_normalized - logits_sample_normalized))
        l2_distance = np.linalg.norm(logits_normalized - logits_sample_normalized)
        class_id = probs.argmax().item()
        score = probs[class_id].item()
        category_name = names[class_id]
        print(f"sample:{sample:2d} cosine: {cosine_distance:.5f} l1: {l1_distance:.5f} l2: {l2_distance:.5f} {category_name}: {100 * score:.1f}%")
        y[:,sample] = [cosine_distance, l1_distance, l2_distance]
    # first-order Sobol indices, one set per distance measure
    sobol_indices_cosine = sobol(x, y[0], num_legendre)
    sobol_indices_l1 = sobol(x, y[1], num_legendre)
    sobol_indices_l2 = sobol(x, y[2], num_legendre)
    # string-formatted copies for the gradio DataFrame outputs
    hdmr_indices_cosine_df = [[f'{b:.4f}' for b in a] for a in sobol_indices_cosine]
    hdmr_indices_l1_df = [[f'{b:.4f}' for b in a] for a in sobol_indices_l1]
    hdmr_indices_l2_df = [[f'{b:.4f}' for b in a] for a in sobol_indices_l2]
    # normalize the highest-degree indices to [0, 1] for the heatmaps
    weight_cosine = sobol_indices_cosine[:,-1] / np.max(sobol_indices_cosine[:,-1])
    weight_l1 = sobol_indices_l1[:,-1] / np.max(sobol_indices_l1[:,-1])
    weight_l2 = sobol_indices_l2[:,-1] / np.max(sobol_indices_l2[:,-1])
    print('weight_cosine',weight_cosine.shape)
    _, hdmr_cosine = segment_heatmap_image(input_image, masks, weight_cosine,num_features_hdmr)
    _, hdmr_l1 = segment_heatmap_image(input_image, masks, weight_l1,num_features_hdmr)
    _, hdmr_l2, = segment_heatmap_image(input_image, masks, weight_l2,num_features_hdmr)
    return hdmr_cosine, hdmr_l1, hdmr_l2, hdmr_indices_cosine_df, hdmr_indices_l1_df, hdmr_indices_l2_df,sam_segmented_image
def fetch_sam_model(sam_model_name_checkpoint):
    """Download (if needed) and instantiate a Segment Anything model.

    Args:
        sam_model_name_checkpoint: space-separated "<registry_name> <checkpoint_file>",
            e.g. "vit_b sam_vit_b_01ec64.pth".

    Returns:
        The SAM model built by sam_model_registry.
    """
    sam_model_name, sam_checkpoint = sam_model_name_checkpoint.split(' ')
    url = f"https://dl.fbaipublicfiles.com/segment_anything/{sam_checkpoint}"
    # Cache the checkpoint in the working directory so repeat runs skip the
    # download (the original bound wget's return value to an unused variable).
    if not os.path.isfile(sam_checkpoint):
        wget.download(url, sam_checkpoint)
    sam = sam_model_registry[sam_model_name](checkpoint=sam_checkpoint)
    return sam
def fetch_model_names():
    """Return the custom EuroSAT model name followed by all torchvision model names."""
    return ['EUROSAT_CUSTOM_MODEL'] + models.list_models(module=torchvision.models)
def fetch_model(model_name):
    """Resolve a model name to (model, weights, preprocess, class_names).

    'EUROSAT_CUSTOM_MODEL' loads the local CNN checkpoint; any other name is
    looked up in torchvision and its first IMAGENET1K weight set is used.

    Args:
        model_name: either 'EUROSAT_CUSTOM_MODEL' or a torchvision model name.

    Returns:
        (model, weights, preprocess transform, list of class names), or
        (None, None, None, None) when the torchvision model has no
        IMAGENET1K weights.
    """
    print('Retrieving model ', model_name)
    if model_name == "EUROSAT_CUSTOM_MODEL":
        model = CNN()
        # map_location keeps the checkpoint loadable on CPU-only hosts even
        # if it was saved from a GPU (torch.load would otherwise fail).
        state_dict = torch.load('EUROSAT_CUSTOM_MODEL.pth', map_location='cpu')
        model.load_state_dict(state_dict)
        return (model, state_dict, eurosat_transform,
                ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway',
                 'Industrial', 'Pasture', 'PermanentCrop', 'Residential',
                 'River', 'SeaLake'])
    weights_enum = models.get_model_weights(model_name)
    for w in weights_enum:
        # take the first IMAGENET1K weight variant offered for this model
        if "IMAGENET1K" in w.name:
            model = models.get_model(model_name, weights=w)
            print('Model weights loaded', w.name)
            return (model, w,
                    w.transforms(antialias=True),
                    w.meta['categories'])
    return None, None, None, None
# ---------------------------------------------------------------------------
# Gradio UI: inputs on top (image/model pickers, LIME params, HDMR params),
# outputs below (class table, LIME image, SAM overlay, HDMR heatmaps/indices).
# One Run button triggers both run_lime and run_hdmr.
# ---------------------------------------------------------------------------
with gd.Blocks() as demo:
    with gd.Column():
        gd.Markdown(value='''
# xAI with a Meta-Modelling Algorithm
And its comparison with LIME.
LIME implementation is based on:
* [LIME](https://github.com/marcotcr/lime)
* [LIME tutorial](https://github.com/marcotcr/lime/blob/master/tutorials/lime_image.ipynb)
''')
    with gd.Row():
        # Column 1: image to explain plus classifier / SAM model selection
        with gd.Column():
            input_image = gd.Image(label="Input Image. Please upload an image that you want LIME to explain")
            model_name = gd.Dropdown(label="Model",
                                     info='''
Select the image classification model to use for LIME.
The list is automatically populated by using torchvision library.
''',
                                     value='EUROSAT_CUSTOM_MODEL',
                                     choices=fetch_model_names())
            sam_model_name = gd.Dropdown(label="SAM model",
                                         info='Select the SAM model',
                                         value='vit_b sam_vit_b_01ec64.pth',
                                         choices=['vit_b sam_vit_b_01ec64.pth'])
        # Column 2: LIME hyper-parameters
        with gd.Column():
            top_labels = gd.Number(label='top_labels', info='''
use the first <top_labels> labels to create explanations.
For example, setting top_labels=5 will create explanations
for the top 5 most likely classes.''',
                                   precision=0, value=5)
            num_samples = gd.Number(label="num_samples",
                                    info="How many samples to be created to build the linear model inside LIME",
                                    precision=0, value=100)
            num_features = gd.Number(label="num_features",
                                     info='Among the most important superpixels (features), how many to be shown in the explanation image',
                                     precision=0, value=2)
            batch_size = gd.Number(label="batch_size",
                                   info='how many images in the samples to be processed at once',
                                   precision=0, value=20)
        # Column 3: HDMR hyper-parameters
        with gd.Column():
            num_samples_hdmr = gd.Number(label="num_samples_hdmr",
                                         info="How many samples in HDMR",
                                         precision=0, value=10)
            num_legendre = gd.Number(label="num_legendre",
                                     info='Number of Legendre Bases for HDMR',
                                     precision=0, value=3)
            num_features_hdmr = gd.Number(label="num_features_hdmr",
                                          info='Among the most important segments, how many to be shown in the explanation image',
                                          precision=0, value=2)
    run_button = gd.Button(label="Run")
    # Output row 1: classifier ranking, LIME explanation, SAM segmentation
    with gd.Row():
        top_10_classes = gd.DataFrame(label="Top 10 classes",
                                      info="Top-10 classes for the input image calculated by using the selected model",
                                      headers=["class_id", "label", "probability"],
                                      datatype=["number", "str", "number"])
        lime_output = gd.Image(label="Lime Explanation",
                               info="The explanation image for the input image calculated by LIME for the selected model")
        sam_segmented_image = gd.Image(label="SAM Segmentation",
                                       info="The segmentation image for the input image calculated by SAM")
    # Output row 2: HDMR heatmaps per distance measure
    with gd.Row():
        hdmr_cosine = gd.Image(label="HDMR Explanation via Cosine Distance")
        hdmr_l1 = gd.Image(label="HDMR Explanation via L1 Distance")
        hdmr_l2 = gd.Image(label="HDMR Explanation via L2 Distance")
    # Output row 3: raw Sobol index tables backing the heatmaps
    with gd.Row():
        hdmr_cosine_indices = gd.DataFrame(label="HDMR Cosine Indices")
        hdmr_l1_indices = gd.DataFrame(label="HDMR L1 Indices")
        hdmr_l2_indices = gd.DataFrame(label="HDMR L2 Indices")
    gd.Examples(
        label="Some examples images and parameters",
        examples=[["jeep.png", "convnext_tiny", 5, 20, 2, 20],
                  ["IMG_0154.jpg", "convnext_tiny", 5, 100, 2, 20],
                  ["IMG_0155.jpg", "convnext_tiny", 5, 100, 2, 20],
                  ["IMG_0156.jpg", "convnext_tiny", 5, 100, 2, 20],
                  ["IMG_0157.jpg", "convnext_tiny", 5, 100, 2, 20],
                  ["IMG_0158.jpg", "convnext_tiny", 5, 100, 2, 20],
                  ["IMG_0159.jpg", "convnext_tiny", 5, 100, 2, 20],
                  ["IMG_0160.jpg", "convnext_tiny", 5, 100, 2, 20]],
        inputs=[input_image, model_name, top_labels, num_samples, num_features, batch_size])
    # Both explainers are wired to the same button; gradio runs both callbacks.
    run_button.click(fn=run_lime, inputs=[input_image, model_name, top_labels, num_samples, num_features, batch_size],
                     outputs=[lime_output, top_10_classes])
    run_button.click(fn=run_hdmr, inputs=[input_image, model_name, sam_model_name, num_samples_hdmr, num_legendre, num_features_hdmr],
                     outputs=[hdmr_cosine, hdmr_l1, hdmr_l2,
                              hdmr_cosine_indices, hdmr_l1_indices, hdmr_l2_indices,
                              sam_segmented_image])
if __name__ == "__main__":
    demo.launch()