Spaces:
Build error
Build error
| import gradio as gd | |
| import numpy as np | |
| import os | |
| import torch | |
| import torchvision | |
| import torchvision.models as models | |
| from lime import lime_image | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| import torch.nn.functional as F | |
| from skimage.segmentation import mark_boundaries | |
| from PIL import Image | |
| from segment_anything import SamAutomaticMaskGenerator, sam_model_registry | |
| import wget | |
| import cv2 | |
| matplotlib.use('agg') | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader, random_split | |
| from torchvision import datasets, transforms, models | |
| import torch.optim as optim | |
| import torch.nn.functional as F | |
| import matplotlib.pyplot as plt | |
| from torchvision.transforms import v2 | |
# Vanilla Legendre polynomial P_m(x); standard Legendre, orthogonal on [-1, 1].
def Pn(m, x):
    """Evaluate the Legendre polynomial of degree ``m`` at ``x`` (elementwise).

    Uses Bonnet's recurrence iteratively:

        k * P_k(x) = (2k - 1) * x * P_{k-1}(x) - (k - 1) * P_{k-2}(x)

    The original naive double recursion re-evaluated both P_{m-1} and
    P_{m-2} on every call, making the cost exponential in ``m``; the
    iterative form below computes the same values in O(m).

    Args:
        m: non-negative polynomial degree.
        x: numpy array (or scalar) of evaluation points.

    Returns:
        Array of the same shape as ``x`` with P_m evaluated pointwise.
    """
    p_prev = np.ones_like(x)  # P_0
    if m == 0:
        return p_prev
    p_curr = x                # P_1
    for k in range(2, m + 1):
        p_prev, p_curr = p_curr, ((2 * k - 1) * x * p_curr - (k - 1) * p_prev) / k
    return p_curr
# Orthonormal Legendre basis shifted to an arbitrary interval [a, b].
def L(a, b, m, x):
    """Degree-``m`` orthonormal (shifted) Legendre basis function on [a, b].

    The affine change of variable ``t = 2*(x - b)/(b - a) + 1`` sends
    x = a to -1 and x = b to +1, and the factor sqrt((2m + 1)/(b - a))
    normalizes the basis to unit L2 norm over [a, b].
    """
    t = 2 * (x - b) / (b - a) + 1
    scale = np.sqrt((2 * m + 1) / (b - a))
    return scale * Pn(m, t)
# Preprocessing for the custom EuroSAT CNN: input tensors arrive in [0, 255];
# the first Normalize rescales them to [0, 1], the second applies the standard
# ImageNet channel statistics.
eurosat_transform = transforms.Compose(
    [
        transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[255.0, 255.0, 255.0]),  # normalize to [0,1] first
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize images
    ]
)
class CNN(nn.Module):
    """Small two-conv-layer image classifier for 3x64x64 inputs (e.g. EuroSAT).

    Architecture: conv(3->32) -> pool -> conv(32->64) -> pool -> dropout
    -> fc(64*16*16 -> 512) -> fc(512 -> num_classes).  Two 2x2 max-poolings
    reduce a 64x64 image to 16x16, which fixes the flattened size below.
    """

    def __init__(self, num_classes=10):  # Modify num_classes based on the number of your classes
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        # Two conv + relu + pool stages, then flatten for the dense head.
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 16 * 16)
        x = self.dropout(x)  # dropout on the flattened features
        x = F.relu(self.fc1(x))
        return self.fc2(x)
def run_lime(input_image,
             model_name: str,
             top_labels: int,
             num_samples: int,
             num_features: int,
             batch_size: int):
    """Run a LIME explanation for ``input_image`` on the selected model.

    Args:
        input_image: HxWxC uint8 numpy array in [0, 255] (from the Gradio Image).
        model_name: model key understood by ``fetch_model``.
        top_labels: how many top classes LIME builds explanations for.
        num_samples: number of perturbed samples for LIME's surrogate model.
        num_features: number of superpixels highlighted in the explanation.
        batch_size: how many perturbed images the classifier processes at once.

    Returns:
        Tuple of (explanation image with superpixel boundaries,
        top-10 class table rows as [class_id, label, probability]).
    """
    print('model_name', model_name)
    print('top_labels', top_labels)
    print('num_samples', num_samples)
    print('num_features', num_features)
    print('batch_size', batch_size)
    print('input image', type(input_image), input_image.shape)
    model, weights, preprocess, names = fetch_model(model_name)
    # HWC uint8 -> CHW float tensor -> model-specific preprocessing -> NCHW.
    input_image_processed = preprocess(
        torch.from_numpy(input_image.astype(np.float32).transpose(2, 0, 1))).unsqueeze(0)
    logits = model(input_image_processed)
    probs = F.softmax(logits, dim=1)
    top_10_classes = []
    print('probs', type(probs), probs.shape)
    for x in probs.argsort(descending=True)[0][:10]:
        print(x.item(), names[x], probs[0, x].item())
        top_10_classes.append([x.item(), names[x], f'{probs[0, x].item():.4f}'])

    def classifier_fn(images):
        # LIME hands us a batch of perturbed HxWxC images.  The FINAL batch
        # can be smaller than batch_size, so the tensor must be sized from
        # len(images): the original loop iterated range(batch_size) and
        # raised IndexError on a short final batch.
        print('classifier_fn', type(images), images.shape)
        n = len(images)
        zz = preprocess(torch.from_numpy(images[0].transpose(2, 0, 1).astype(np.float32)))
        c, w, h = zz.shape
        batch = torch.zeros(n, c, w, h)
        print('len(images)', n)
        for i in range(n):
            batch[i] = preprocess(torch.from_numpy(images[i].transpose(2, 0, 1).astype(np.float32)))
        print('batch', type(batch), batch.shape)
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
        print('probs', type(probs), probs.shape)
        return probs.detach().cpu().numpy()

    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        input_image,
        classifier_fn,
        top_labels=top_labels,
        hide_color=0,
        num_samples=num_samples,
        num_features=num_features,
        batch_size=batch_size)
    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0],
        positive_only=False, num_features=num_features, hide_rest=False)
    # mark_boundaries expects floats in [0, 1].
    lime_output = mark_boundaries(temp / 255.0, mask)
    return lime_output, top_10_classes
def segmented_image(img, masks, alpha=0.7):
    """Overlay every SAM mask on ``img`` with a random color.

    Args:
        img: HxWx3 uint8 image.
        masks: SAM masks, each carrying a boolean 'segmentation' array.
        alpha: blend weight of the colored overlay (1 - alpha keeps the original).

    Returns:
        The blended uint8 image.
    """
    overlay = img.copy()
    for m in masks:
        # Paint the masked region with a random RGB color.
        overlay[m['segmentation'] == 1] = 255 * np.random.random(3)
    # In-place blend: overlay = alpha * overlay + (1 - alpha) * img.
    cv2.addWeighted(overlay, alpha, img, 1.0 - alpha, 0, overlay)
    return overlay
def segment_heatmap_image(img, masks, mask_weights, num_features_hdmr):
    """Render a JET heatmap of the most important SAM segments.

    Args:
        img: HxWx3 RGB uint8 image.
        masks: SAM masks, each carrying a boolean 'segmentation' array.
        mask_weights: one importance weight in [0, 1] per mask.
        num_features_hdmr: number of top-weighted segments to draw.

    Returns:
        Tuple (heatmap superimposed on a brightened grayscale copy,
        raw heatmap image).
    """
    w, h, c = img.shape  # NOTE(review): these are actually (rows, cols, channels)
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # Brighten the grayscale backdrop so the overlay stays readable.
    gray = cv2.convertScaleAbs(gray, alpha=10, beta=0)
    gray3 = np.dstack([gray, gray, gray])
    print(img.shape, gray.shape)
    intensity = np.zeros((w, h)).astype(np.uint8)
    # Keep only the num_features_hdmr segments with the largest weights.
    top_indices = mask_weights.argsort()[-num_features_hdmr:]
    for idx in top_indices:
        seg = masks[idx]['segmentation']
        intensity[seg == True] = int(255 * mask_weights[idx])
    heatmap_img = cv2.applyColorMap(intensity, cv2.COLORMAP_JET)
    super_imposed_img = cv2.addWeighted(heatmap_img, 1, gray3, 0.6, 0)
    return super_imposed_img, heatmap_img
def sobol(x, y, m):
    """First-order Sobol sensitivity indices via a Legendre meta-model.

    Args:
        x: (N, n) sample matrix; each column is one factor sampled in [0, 1].
        y: length-N array of model outputs for those samples.
        m: maximum Legendre degree used in the expansion.

    Returns:
        (n, m) array S where S[k, d] is the first-order index of factor k
        computed from Legendre degrees 1..d+1.
    """
    print(x.shape, y.shape)
    N, n = x.shape
    f0 = np.mean(y)
    # Expansion coefficients: alpha[r, i] = E[(y - f0) * L_{r+1}(x_i)].
    alpha = np.zeros((m, n))
    for r in range(m):
        for i in range(n):
            alpha[r, i] = np.mean((y - f0) * L(0, 1, r + 1, np.array(x[:, i])))
    # Total output variance.
    global_D = np.mean(y ** 2) - np.mean(y) ** 2
    # Partial variances are cumulative sums of squared coefficients over the
    # degree axis; the original recomputed the sum for every degree, which
    # was O(m^2) — a single cumsum yields identical values in O(m).
    D_first_order = np.cumsum(alpha ** 2, axis=0).T  # shape (n, m)
    S_first_order = D_first_order / global_D
    return S_first_order
def run_hdmr(input_image,
             model_name: str,
             sam_model_name: str,
             num_samples_hdmr: int,
             num_legendre: int,
             num_features_hdmr: int):
    """HDMR/Sobol explanation of the model's prediction on ``input_image``.

    SAM segments the image; each segment is randomly dimmed; first-order
    Sobol indices of three distance metrics (cosine, L1, L2 between the
    perturbed and reference logits) rank segment importance.

    ``input_image`` is an HxWxC numpy array in [0, 255].  Returns the three
    heatmaps, their index tables, and the SAM segmentation overlay.
    """
    print('model_name', model_name)
    print('sam_model_name', sam_model_name)
    print('num_samples_hdmr', num_samples_hdmr)
    print('num_features_hdmr', num_features_hdmr)
    print('input image', type(input_image), input_image.shape)
    model, weights, preprocess, names = fetch_model(model_name)
    sam_model = fetch_sam_model(sam_model_name)
    masks = SamAutomaticMaskGenerator(sam_model).generate(input_image)
    sam_segmented_image = segmented_image(input_image, masks, alpha=0.9)
    # Reference (unperturbed) logits, unit-normalized for the cosine metric.
    batch = preprocess(torch.from_numpy(
        input_image.astype(np.float32).transpose(2, 0, 1))).unsqueeze(0)
    logits_normalized = model(batch)[0].detach().numpy()
    logits_normalized = logits_normalized / np.linalg.norm(logits_normalized)
    print('logits_normalized', logits_normalized.shape)
    N = num_samples_hdmr
    n = len(masks)
    x = np.random.rand(N, n)  # one multiplier in [0, 1) per (sample, segment)
    y = np.zeros((3, N))      # rows: cosine, l1, l2 distances
    # TODO: implement batch_size
    for sample in range(N):
        # Dim each segment by the SQUARE of its sampled multiplier.
        perturbed = input_image.copy()
        for i, mask in enumerate(masks):
            seg = mask['segmentation']
            perturbed[seg == 1] = perturbed[seg == 1] * np.power(x[sample, i], 2)
        batch = preprocess(torch.from_numpy(
            perturbed.astype(np.float32).transpose(2, 0, 1))).unsqueeze(0)
        logits_sample = model(batch)
        probs = logits_sample.squeeze(0).softmax(0)
        vec = logits_sample[0].detach().numpy()
        vec = vec / np.linalg.norm(vec)
        cosine_distance = np.dot(logits_normalized, vec)
        l1_distance = np.sum(np.abs(logits_normalized - vec))
        l2_distance = np.linalg.norm(logits_normalized - vec)
        class_id = probs.argmax().item()
        score = probs[class_id].item()
        category_name = names[class_id]
        print(f"sample:{sample:2d} cosine: {cosine_distance:.5f} l1: {l1_distance:.5f} l2: {l2_distance:.5f} {category_name}: {100 * score:.1f}%")
        y[:, sample] = [cosine_distance, l1_distance, l2_distance]
    # First-order Sobol indices, one run per distance metric.
    sobol_indices_cosine = sobol(x, y[0], num_legendre)
    sobol_indices_l1 = sobol(x, y[1], num_legendre)
    sobol_indices_l2 = sobol(x, y[2], num_legendre)
    # Human-readable tables for the Gradio DataFrames.
    hdmr_indices_cosine_df = [[f'{b:.4f}' for b in a] for a in sobol_indices_cosine]
    hdmr_indices_l1_df = [[f'{b:.4f}' for b in a] for a in sobol_indices_l1]
    hdmr_indices_l2_df = [[f'{b:.4f}' for b in a] for a in sobol_indices_l2]
    # Segment weights: highest-degree index per segment, rescaled to [0, 1].
    weight_cosine = sobol_indices_cosine[:, -1] / np.max(sobol_indices_cosine[:, -1])
    weight_l1 = sobol_indices_l1[:, -1] / np.max(sobol_indices_l1[:, -1])
    weight_l2 = sobol_indices_l2[:, -1] / np.max(sobol_indices_l2[:, -1])
    print('weight_cosine', weight_cosine.shape)
    _, hdmr_cosine = segment_heatmap_image(input_image, masks, weight_cosine, num_features_hdmr)
    _, hdmr_l1 = segment_heatmap_image(input_image, masks, weight_l1, num_features_hdmr)
    _, hdmr_l2 = segment_heatmap_image(input_image, masks, weight_l2, num_features_hdmr)
    return (hdmr_cosine, hdmr_l1, hdmr_l2,
            hdmr_indices_cosine_df, hdmr_indices_l1_df, hdmr_indices_l2_df,
            sam_segmented_image)
def fetch_sam_model(sam_model_name_checkpoint):
    """Load a SAM model, downloading its checkpoint on first use.

    Args:
        sam_model_name_checkpoint: space-separated
            '<registry_key> <checkpoint_file>' string,
            e.g. 'vit_b sam_vit_b_01ec64.pth'.

    Returns:
        The instantiated SAM model.
    """
    sam_model_name, sam_checkpoint = sam_model_name_checkpoint.split(' ')
    url = f"https://dl.fbaipublicfiles.com/segment_anything/{sam_checkpoint}"
    # Download only when the checkpoint is not cached locally; the original
    # bound wget's return value to an unused `response` variable.
    if not os.path.isfile(sam_checkpoint):
        wget.download(url, sam_checkpoint)
    return sam_model_registry[sam_model_name](checkpoint=sam_checkpoint)
def fetch_model_names():
    """List selectable classifiers: the custom EuroSAT CNN first, followed by
    every classification model torchvision knows about."""
    return ['EUROSAT_CUSTOM_MODEL'] + models.list_models(module=torchvision.models)
def fetch_model(model_name):
    """Resolve a model name into (model, weights, preprocess, class_names).

    For 'EUROSAT_CUSTOM_MODEL' the local CNN checkpoint is loaded; for any
    other name, the first torchvision weight set whose name contains
    'IMAGENET1K' is used.  Returns (None, None, None, None) when no such
    weight set exists.
    """
    print('Retrieving model ', model_name)
    if model_name == "EUROSAT_CUSTOM_MODEL":
        model = CNN()
        # map_location keeps the checkpoint loadable on CPU-only hosts even
        # if it was saved from a GPU run; the whole app feeds CPU tensors.
        weights = torch.load('EUROSAT_CUSTOM_MODEL.pth', map_location='cpu')
        model.load_state_dict(weights)
        # Inference only: without eval() the CNN's Dropout stays active and
        # predictions are nondeterministic.
        model.eval()
        return (model, weights, eurosat_transform,
                ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway',
                 'Industrial', 'Pasture', 'PermanentCrop', 'Residential',
                 'River', 'SeaLake'])
    weights_enum = models.get_model_weights(model_name)
    for w in weights_enum:
        if "IMAGENET1K" in w.name:
            weights = w
            model = models.get_model(model_name, weights=weights)
            # eval() disables dropout/batch-norm training behavior; models
            # here are only ever used for inference on single images.
            model.eval()
            print('Model weights loaded', w.name)
            return (model, weights,
                    weights.transforms(antialias=True),
                    weights.meta['categories'])
    return None, None, None, None
# --- Gradio UI: wires the LIME and HDMR pipelines to a web front-end ---
with gd.Blocks() as demo:
    with gd.Column():
        # Intro text shown at the top of the app.
        gd.Markdown(value='''
# xAI with a Meta-Modelling Algorithm
And its comparison with LIME.
LIME implementation is based on:
* [LIME](https://github.com/marcotcr/lime)
* [LIME tutorial](https://github.com/marcotcr/lime/blob/master/tutorials/lime_image.ipynb)
''')
        with gd.Row():
            with gd.Column():
                # Inputs shared by both pipelines: image and model choices.
                input_image = gd.Image(label="Input Image. Please upload an image that you want LIME to explain")
                model_name = gd.Dropdown(label="Model",
                                         info='''
Select the image classification model to use for LIME.
The list is automatically populated by using torchvision library.
''',
                                         value='EUROSAT_CUSTOM_MODEL',
                                         choices=fetch_model_names())
                sam_model_name = gd.Dropdown(label="SAM model",
                                             info='Select the SAM model',
                                             value='vit_b sam_vit_b_01ec64.pth',
                                             choices=['vit_b sam_vit_b_01ec64.pth'])
            with gd.Column():
                # LIME hyper-parameters.
                top_labels = gd.Number(label='top_labels', info='''
use the first <top_labels> labels to create explanations.
For example, setting top_labels=5 will create explanations
for the top 5 most likely classes.''',
                                       precision=0, value=5)
                num_samples = gd.Number(label="num_samples",
                                        info="How many samples to be created to build the linear model inside LIME",
                                        precision=0, value=100)
                num_features = gd.Number(label="num_features",
                                         info='Among the most important superpixels (features), how many to be shown in the explanation image',
                                         precision=0, value=2)
                batch_size = gd.Number(label="batch_size",
                                       info='how many images in the samples to be processed at once',
                                       precision=0, value=20)
            with gd.Column():
                # HDMR hyper-parameters.
                num_samples_hdmr = gd.Number(label="num_samples_hdmr",
                                             info="How many samples in HDMR",
                                             precision=0, value=10)
                num_legendre = gd.Number(label="num_legendre",
                                         info='Number of Legendre Bases for HDMR',
                                         precision=0, value=3)
                num_features_hdmr = gd.Number(label="num_features_hdmr",
                                              info='Among the most important segments, how many to be shown in the explanation image',
                                              precision=0, value=2)
        # NOTE(review): gd.Button takes `value=` for its text in current Gradio;
        # `label=` may be rejected — confirm against the pinned Gradio version.
        run_button = gd.Button(label="Run")
        with gd.Row():
            # Outputs of the LIME pipeline.
            # NOTE(review): `info=` on DataFrame/Image components is not a
            # documented Gradio parameter — a possible cause of the build
            # error; verify against the pinned Gradio version.
            top_10_classes = gd.DataFrame(label="Top 10 classes",
                                          info="Top-10 classes for the input image calculated by using the selected model",
                                          headers=["class_id", "label", "probability"],
                                          datatype=["number", "str", "number"])
            lime_output = gd.Image(label="Lime Explanation",
                                   info="The explanation image for the input image calculated by LIME for the selected model")
            sam_segmented_image = gd.Image(label="SAM Segmentation",
                                           info="The segmentation image for the input image calculated by SAM")
        with gd.Row():
            # Outputs of the HDMR pipeline: one heatmap per distance metric.
            hdmr_cosine = gd.Image(label="HDMR Explanation via Cosine Distance")
            hdmr_l1 = gd.Image(label="HDMR Explanation via L1 Distance")
            hdmr_l2 = gd.Image(label="HDMR Explanation via L2 Distance")
        with gd.Row():
            hdmr_cosine_indices = gd.DataFrame(label="HDMR Cosine Indices")
            hdmr_l1_indices = gd.DataFrame(label="HDMR L1 Indices")
            hdmr_l2_indices = gd.DataFrame(label="HDMR L2 Indices")
        # Pre-filled example rows for the LIME inputs (image files must ship
        # with the Space).
        gd.Examples(
            label="Some examples images and parameters",
            examples=[["jeep.png", "convnext_tiny", 5, 20, 2, 20],
                      ["IMG_0154.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0155.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0156.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0157.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0158.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0159.jpg", "convnext_tiny", 5, 100, 2, 20],
                      ["IMG_0160.jpg", "convnext_tiny", 5, 100, 2, 20]],
            inputs=[input_image, model_name, top_labels, num_samples, num_features, batch_size])
        # One click fires BOTH pipelines; Gradio runs the two handlers
        # independently and fills their respective outputs.
        run_button.click(fn=run_lime, inputs=[input_image, model_name, top_labels, num_samples, num_features, batch_size],
                         outputs=[lime_output, top_10_classes])
        run_button.click(fn=run_hdmr, inputs=[input_image, model_name, sam_model_name, num_samples_hdmr, num_legendre, num_features_hdmr],
                         outputs=[hdmr_cosine, hdmr_l1, hdmr_l2,
                                  hdmr_cosine_indices, hdmr_l1_indices, hdmr_l2_indices,
                                  sam_segmented_image])
# Launch the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()