# Spaces:
# Runtime error
# Runtime error
# (HuggingFace Spaces page header captured with the source; kept as comments
# so the file remains valid Python.)
# Standard library
import sys
from typing import List

# Third-party
import cv2
import gradio as gr
import numpy as np
import PIL
from PIL import Image
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from pytorch_grad_cam import (
    GradCAM,
    AblationCAM,
    FullGrad,
    EigenGradCAM,
    LayerCAM,
    DeepFeatureFactorization,
)
from pytorch_grad_cam.utils.image import (
    show_cam_on_image,
    preprocess_image,
    deprocess_image,
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# Local
from utils.save_load import load_model
# The 27 architectural-style classes the classifier predicts, in the index
# order used by the model's output layer (see the classifier head below,
# which is sized with len(labels)).
labels = [
    "Achaemenid architecture",
    "American craftsman style",
    "American Foursquare architecture",
    "Ancient Egyptian architecture",
    "Art Deco architecture",
    "Art Nouveau architecture",
    "Baroque architecture",
    "Bauhaus architecture",
    "Beaux-Arts architecture",
    "Brutalism architecture",
    "Byzantine architecture",
    "Chicago school architecture",
    "Colonial architecture",
    "Deconstructivism",
    "Edwardian architecture",
    "Georgian architecture",
    "Gothic architecture",
    "Greek Revival architecture",
    "International style",
    "Islamic architecture",
    "Novelty architecture",
    "Palladian architecture",
    "Postmodern architecture",
    "Queen Anne architecture",
    "Romanesque architecture",
    "Russian Revival architecture",
    "Tudor Revival architecture",
]
# (A leftover debug `print(len(labels))` at import time was removed.)
# EfficientNetV2-L backbone with the stock classifier head replaced by a
# len(labels)-way linear layer (1280 is the backbone's final feature width).
model = torchvision.models.efficientnet_v2_l()
model.classifier = nn.Sequential(
    nn.Dropout(p=0.4, inplace=True),
    nn.Linear(1280, len(labels), bias=True)
)
# Load trained weights into `model` in place (project helper; presumably the
# checkpoint path is baked into utils.save_load — verify against that module).
load_model(model)
# Last feature block of the backbone is used as the CAM/DFF target layer.
# NOTE(review): pytorch_grad_cam expects a *list* of target layers; passing
# the module itself only works because it is iterable, which makes each
# sub-layer of the block a target — confirm this is intended.
target_layers = model.features[-1]
classifier = model.classifier
# LayerCAM is constructed here but only `dff` is used in predict() below.
cam = LayerCAM(model=model, target_layers=target_layers, use_cuda=False)
dff = DeepFeatureFactorization(
    model=model, target_layer=target_layers, computation_on_concepts=classifier)
def show_factorization_on_image(img: np.ndarray,
                                explanations: np.ndarray,
                                colors: List[np.ndarray] = None,
                                image_weight: float = 0.5,
                                concept_labels: List = None) -> np.ndarray:
    """Overlay DFF concept heatmaps on an RGB image, one color per concept.

    Args:
        img: float RGB image in [0, 1], shape (H, W, 3).
        explanations: per-concept heatmaps, shape (n_concepts, H, W).
        colors: optional RGBA color per concept; defaults to evenly spaced
            samples from the 'gist_rainbow' colormap.
        image_weight: blend weight of the original image vs. the overlay.
        concept_labels: optional legend strings, one per concept; when given,
            a rendered legend strip is stacked below the image.

    Returns:
        uint8 RGB visualization (with the legend appended when labels given).
    """
    n_components = explanations.shape[0]
    if colors is None:
        # taken from https://github.com/edocollins/DFF/blob/master/utils.py
        # NOTE(review): plt.cm.get_cmap is deprecated in matplotlib >= 3.7;
        # matplotlib.colormaps['gist_rainbow'] is the replacement.
        _cmap = plt.cm.get_cmap('gist_rainbow')
        colors = [
            np.array(_cmap(i))
            for i in np.arange(0, 1, 1.0 / n_components)]
    # Each pixel is attributed to the concept with the largest explanation.
    concept_per_pixel = explanations.argmax(axis=0)
    masks = []
    for i in range(n_components):
        mask = np.zeros(shape=(img.shape[0], img.shape[1], 3))
        mask[:, :, :] = colors[i][:3]
        # Copy before zeroing the non-winning pixels so the caller's
        # `explanations` array is not mutated (the original wrote in place).
        explanation = explanations[i].copy()
        explanation[concept_per_pixel != i] = 0
        # Encode the explanation strength in the V channel of the concept color.
        mask = np.uint8(mask * 255)
        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2HSV)
        mask[:, :, 2] = np.uint8(255 * explanation)
        mask = cv2.cvtColor(mask, cv2.COLOR_HSV2RGB)
        mask = np.float32(mask) / 255
        masks.append(mask)
    mask = np.sum(np.float32(masks), axis=0)
    result = img * image_weight + mask * (1 - image_weight)
    result = np.uint8(result * 255)
    if concept_labels is not None:
        # Render the legend with matplotlib at the image's pixel width, then
        # rasterize it and stack it under the visualization.
        px = 1 / plt.rcParams['figure.dpi']  # pixel in inches
        fig = plt.figure(figsize=(result.shape[1] * px, result.shape[0] * px))
        plt.rcParams['legend.fontsize'] = 6 * result.shape[0] / 256
        lw = 5 * result.shape[0] / 256
        lines = [Line2D([0], [0], color=colors[i], lw=lw)
                 for i in range(n_components)]
        plt.legend(lines,
                   concept_labels,
                   fancybox=False,
                   shadow=False,
                   frameon=False,
                   loc="center")
        plt.tight_layout(pad=0, w_pad=0, h_pad=0)
        plt.axis('off')
        fig.canvas.draw()
        # NOTE(review): tostring_rgb was removed in matplotlib >= 3.10;
        # buffer_rgba() is the forward-compatible replacement.
        data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        plt.close(fig=fig)
        data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        data = cv2.resize(data, (result.shape[1], result.shape[0]))
        result = np.vstack((result, data))
    return result
def create_labels(concept_scores, top_k=2, label_names=None):
    """Create per-concept legend strings from the top-scoring categories.

    (The original docstring said "image-net category names"; the names
    actually come from this module's architecture-style `labels`.)

    Args:
        concept_scores: (n_concepts, n_classes) score matrix.
        top_k: number of top classes listed per concept.
        label_names: class-name list indexed by category id; defaults to the
            module-level `labels` (new optional parameter, backward
            compatible with existing callers).

    Returns:
        List of newline-joined "name:score%" strings, one per concept.
    """
    if label_names is None:
        label_names = labels
    # Indices of the top_k classes per concept, highest score first.
    concept_categories = np.argsort(concept_scores, axis=1)[:, ::-1][:, :top_k]
    concept_labels_topk = []
    for concept_index in range(concept_categories.shape[0]):
        categories = concept_categories[concept_index, :]
        concept_labels = []
        for category in categories:
            score = concept_scores[concept_index, category]
            # Keep only the part before the first comma, as in the DFF demo.
            label = f"{label_names[category].split(',')[0]}:{score*100:.2f}%"
            concept_labels.append(label)
        concept_labels_topk.append("\n".join(concept_labels))
    return concept_labels_topk
def predict(rgb_img, top_k):
    """Gradio handler: classify a PIL image and visualize its DFF concepts.

    Args:
        rgb_img: PIL RGB image from the Gradio image input.
        top_k: number of class names shown per concept in the legend.

    Returns:
        (confidences, visualization) — a {label: probability} dict for
        gr.Label, and the concept-factorization overlay image.
    """
    print(top_k)
    # NOTE(review): Normalize runs before Resize, so normalization happens at
    # the original resolution — confirm this matches the training pipeline.
    inp_01 = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.4937, 0.5060, 0.5030], [
                0.2705, 0.2653, 0.2998]),
            transforms.Resize((224, 224)),
        ])(rgb_img)
    model.eval()
    with torch.no_grad():
        # Softmax over the single image's logits -> per-class probabilities.
        prediction = torch.nn.functional.softmax(
            model(inp_01.unsqueeze(0))[0], dim=0)
        confidences = {labels[i]: float(prediction[i])
                       for i in range(len(labels))}
    # Factorize deep features into 5 concepts and score each concept with
    # the classifier head (module-level `dff`).
    concepts, batch_explanations, concept_outputs = dff(
        inp_01.unsqueeze(0), 5)
    concept_outputs = torch.softmax(
        torch.from_numpy(concept_outputs), axis=-1).numpy()
    concept_label_strings = create_labels(concept_outputs, top_k=top_k)
    print(inp_01.shape)
    print(batch_explanations[0].shape)
    # Upsample the (n_concepts, h, w) explanation maps to the input image's
    # size; cv2.resize works on HWC layout, hence the transposes around it.
    res = cv2.resize(np.transpose(
        batch_explanations[0], (1, 2, 0)), (rgb_img.size[0], rgb_img.size[1]))
    res = np.transpose(res, (2, 0, 1))
    print(res.shape)
    visualization_01 = show_factorization_on_image(np.float32(rgb_img)/255.0,
                                                   res,
                                                   image_weight=0.3,
                                                   concept_labels=concept_label_strings)
    return confidences, visualization_01,
# Wire up the Gradio UI: an image and a top-k slider in; the per-class
# confidences plus the concept visualization out. launch() blocks and serves.
gr.Interface(fn=predict,
             inputs=[gr.Image(type="pil"), gr.Slider(
                 minimum=1, maximum=4, label="Number of top results", step=1)],
             outputs=[gr.Label(num_top_classes=5), "image"],
             examples=[["./assets/bauhaus.jpg", 1],
                       ["./assets/frank_gehry.jpg", 2], ["./assets/pyramid.jpg", 3]]
             ).launch()
# examples=["./assets/bauhaus.jpg", "./assets/frank_gehry.jpg", "./assets/pyramid.jpg"]