Spaces:
Sleeping
Sleeping
| from concurrent.futures import ThreadPoolExecutor | |
| import copy | |
| import os | |
| import sys | |
| sys.path.append('src') | |
| import shutil | |
| from collections import defaultdict | |
| from functools import lru_cache | |
| import cv2 | |
| import gradio as gr | |
| import mediapy | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import deep_translator | |
| from gradio_blocks import build_video_to_camvideo | |
| from Nets import CustomResNet18 | |
| from PIL import Image, ImageDraw, ImageFont | |
| from pytorch_grad_cam import GradCAM, HiResCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| from tqdm import tqdm | |
| from util import transform | |
# Font used to stamp the predicted class name onto CAM visualizations.
font = ImageFont.truetype("src/Roboto-Regular.ttf", 16)

# Tell mediapy which ffmpeg binary to use. shutil.which returns None when no
# ffmpeg is on PATH; handing None to mediapy would only surface later as an
# obscure failure while writing a video, so keep mediapy's default instead.
ffmpeg_path = shutil.which('ffmpeg')
if ffmpeg_path is not None:
    mediapy.set_ffmpeg(ffmpeg_path)
else:
    print('Warning: ffmpeg was not found on PATH; writing CAM videos may fail.')

# Folder holding the example images shown in the "Example Images" tab.
IMAGE_PATH = os.path.join(os.getcwd(), 'src/examples')
# Number of example images placed in one UI row.
IMAGES_PER_ROW = 5
# Videos with this many frames (or more) are rejected by gradcam_video.
MAXIMAL_FRAMES = 700
# Number of frames sent through the CAM model per batch.
BATCHES_TO_PROCESS = 15
# Frame rate of the written CAM video; -1 means "use the source frame rate".
OUTPUT_FPS = 15
# Videos with more frames than this are subsampled before processing.
MAX_OUT_FRAMES = 90

# Classifier in inference mode with weights mapped onto the CPU.
# presumably 111 is the number of output classes -- TODO confirm against CustomResNet18.
MODEL = CustomResNet18(111).eval()
# NOTE(review): torch.load unpickles the checkpoint, so only trusted files must
# be loaded here (consider weights_only=True if the installed torch supports it).
MODEL.load_state_dict(torch.load('src/results/models/best_model.pth', map_location=torch.device('cpu')))
# UI label -> language code handed to the translator ("None" = no translation).
LANGUAGES_TO_SELECT = {"None": None}
LANGUAGES_TO_SELECT.update(
    German="de",
    French="fr",
    Spanish="es",
    Italian="it",
    Finnish="fi",
    Ukrainian="uk",
)

# UI label -> CAM implementation class.
CAM_METHODS = dict(zip(
    ("GradCAM", "GradCAM++", "XGradCAM", "HiResCAM", "EigenCAM"),
    (GradCAM, GradCAMPlusPlus, XGradCAM, HiResCAM, EigenCAM),
))

# UI label -> target layer(s) of the ResNet backbone used for the CAM.
_resnet_stages = {f'layer{n}': getattr(MODEL.resnet, f'layer{n}') for n in range(1, 5)}
LAYERS = {
    **_resnet_stages,
    'all': list(_resnet_stages.values()),
    'layer3+4': [_resnet_stages['layer3'], _resnet_stages['layer4']],
}

# UI label -> OpenCV colormap constant (cv2.COLORMAP_<LABEL IN UPPER CASE>).
CV2_COLORMAPS = {
    label: getattr(cv2, f"COLORMAP_{label.upper()}")
    for label in (
        "Autumn",
        "Bone",
        "Jet",
        "Winter",
        "Rainbow",
        "Ocean",
        "Summer",
        "Pink",
        "Hot",
        "Magma",
        "Inferno",
        "Plasma",
        "Twilight",
    )
}
# Validation metadata; only used here to derive the class index <-> name tables.
data_df = pd.read_csv('src/cache/val_df.csv')

# One row per distinct (encoded_target, target) pair, ordered by class index.
_class_pairs = (
    data_df[['encoded_target', 'target']]
    .drop_duplicates()
    .sort_values('encoded_target')
)
# Class index -> class name.
C_NUM_TO_NAME = _class_pairs.set_index('encoded_target')['target'].to_dict()
# Class name -> class index (inverse of the table above).
C_NAME_TO_NUM = dict(zip(C_NUM_TO_NAME.values(), C_NUM_TO_NAME.keys()))
# All class names in case-insensitive alphabetical order (used for dropdowns).
ALL_CLASSES = sorted(C_NUM_TO_NAME.values(), key=str.lower)
def get_class_name(idx):
    """Return the class name registered in ``C_NUM_TO_NAME`` for index ``idx``.

    Raises:
        KeyError: if ``idx`` is not a key of ``C_NUM_TO_NAME``.
    """
    class_name = C_NUM_TO_NAME[idx]
    return class_name
def get_class_idx(name):
    """Return the class index registered in ``C_NAME_TO_NUM`` for ``name``.

    Raises:
        KeyError: if ``name`` is not a key of ``C_NAME_TO_NUM``.
    """
    class_idx = C_NAME_TO_NUM[name]
    return class_idx
@lru_cache(maxsize=None)
def _translate_cached(text, language_code):
    """Translate ``text`` from English to ``language_code`` and memoize the result.

    Only successful translations are stored: ``lru_cache`` does not cache calls
    that raise, so a rate-limited request is retried on the next call. The key
    space is bounded by (class names x selectable languages), hence no size cap.
    """
    return deep_translator.GoogleTranslator(source="en", target=language_code).translate(text)


def get_translated(to_translate, target_language="German"):
    """Translate an English text into the selected language.

    Args:
        to_translate: English text to translate.
        target_language: A key of ``LANGUAGES_TO_SELECT`` (e.g. ``"German"``)
            or an already resolved language code (e.g. ``"de"``).

    Returns:
        The translated text. The input is returned unchanged when the target
        resolves to ``None`` (the "None" choice) or to ``"en"``. ``"-/-"`` is
        returned when the translation service reports too many requests.

    Raises:
        gr.Error: if the resolved language code is not one of the selectable ones.
    """
    # Accept both the UI label and the raw language code.
    target_language = LANGUAGES_TO_SELECT.get(target_language, target_language)
    # No target language (or English -> English) means there is nothing to translate.
    if target_language is None or target_language == "en":
        return to_translate
    if target_language not in LANGUAGES_TO_SELECT.values():
        raise gr.Error(f'Language {target_language} not found.')
    try:
        return _translate_cached(to_translate, target_language)
    except deep_translator.exceptions.TooManyRequests:
        print(f'Too many requests for {to_translate} to {target_language}.')
        return "-/-"
# Warm-up: request the translation of every class name in every selectable
# language, fanned out over a thread pool because the work is network-bound.
# NOTE(review): the results are deliberately not consumed, so a failing request
# cannot abort start-up; this pass only pays off if get_translated memoizes.
with ThreadPoolExecutor(max_workers=30) as executor:
    for language, language_code in tqdm(LANGUAGES_TO_SELECT.items(), desc='Preloading translations'):
        if language_code is None:
            # The "None" choice disables translation; there is nothing to request.
            continue
        # executor.map only schedules the calls; leaving the ``with`` block
        # waits until all of them have finished.
        executor.map(get_translated, ALL_CLASSES, [language] * len(ALL_CLASSES))
def infer_image(image, target_language):
    """Classify ``image`` and return the probability of every class.

    Args:
        image: PIL image to classify. It is also written to
            ``src/results/infer_image.png``.
        target_language: Language used to append a translation to each class
            name; ``None`` or ``"None"`` disables the translation.

    Returns:
        A ``defaultdict(float)`` mapping ``"<class name>"`` or
        ``"<class name> (<translation>)"`` to the softmax probability.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    image.save('src/results/infer_image.png')
    batch = transform(image).unsqueeze(0)

    with torch.no_grad():
        logits = MODEL(batch)
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

    wants_translation = not (target_language is None or target_language == "None")
    scores = defaultdict(float)
    for class_idx, probability in enumerate(probabilities):
        label = f'{get_class_name(class_idx)}'
        if wants_translation:
            label += f' ({get_translated(get_class_name(class_idx), target_language)})'
        scores[label] = probability.item()
    return scores
def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class", label_image=True, target_lang="German"):
    """Render a class activation map for ``image``.

    Args:
        image: PIL image, or a dict with PIL images under ``"image"`` and
            ``"mask"`` which are blended 50/50 into one image.
        colormap: Key of ``CV2_COLORMAPS`` used for the heat-map overlay.
        use_eigen_smooth: Forwarded to the CAM call as ``eigen_smooth``.
        use_aug_smooth: Forwarded to the CAM call as ``aug_smooth``.
        BWHighlight: If true, multiply the image by the CAM map instead of
            overlaying a colored heat map (``colormap`` and ``alpha`` are unused).
        alpha: Passed to ``show_cam_on_image`` as ``image_weight``.
        cam_method: Key of ``CAM_METHODS`` (what the UI sends) or a CAM class
            (the signature default).
        layer: Key of ``LAYERS``; ``None`` selects ``"layer4"``, the UI default.
        specific_class: Class name to explain, or ``"Predicted Class"`` to let
            the CAM implementation choose the target itself.
        label_image: If true, write the predicted class name onto the result.
        target_lang: Language of the translation appended to the label;
            ``None`` or ``"None"`` disables it.

    Returns:
        A PIL image (RGBA when labelled, RGB otherwise).

    Raises:
        gr.Error: on a missing image, an unknown colormap / CAM method / layer,
            or an image larger than 6000x6000.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if isinstance(image, dict):
        # Image and mask both arrive as PIL images -> combine them into one image.
        image = Image.blend(image["image"], image["mask"], alpha=0.5)
    if colormap not in CV2_COLORMAPS:
        raise gr.Error(f"Colormap {colormap} not found in {list(CV2_COLORMAPS.keys())}.")
    colormap = CV2_COLORMAPS[colormap]

    # The UI passes the method name, while the signature default is the class
    # itself; accept both instead of failing on the dictionary lookup.
    if isinstance(cam_method, str):
        if cam_method not in CAM_METHODS:
            raise gr.Error(f"CAM method {cam_method} not found in {list(CAM_METHODS.keys())}.")
        cam_class = CAM_METHODS[cam_method]
    else:
        cam_class = cam_method

    # ``layer=None`` (the signature default) falls back to the UI default.
    layer_name = "layer4" if layer is None else layer
    if layer_name not in LAYERS:
        raise gr.Error(f"Layer {layer_name} not found in {list(LAYERS.keys())}.")
    layers = LAYERS[layer_name]

    image_width, image_height = image.size
    if image_width > 6000 or image_height > 6000:
        raise gr.Error("The image is too big. The maximal size is 6000x6000.")

    # Everything below assumes three color channels (heat-map overlay, and the
    # alpha channel appended for the label), so normalize other modes first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    MODEL.eval()
    image_tensor = transform(image).unsqueeze(0)
    targets = [ClassifierOutputTarget(get_class_idx(specific_class))] if specific_class != "Predicted Class" else None

    predicted_animal = None
    with cam_class(model=MODEL, target_layers=layers) as cam:
        grayscale_cam = cam(input_tensor=image_tensor, targets=targets, aug_smooth=use_aug_smooth, eigen_smooth=use_eigen_smooth)
        if label_image:
            # NOTE(review): relies on MODEL exposing the logits of its last
            # forward pass as ``MODEL.output`` -- confirm in CustomResNet18.
            predicted_animal = get_class_name(np.argmax(MODEL.output.cpu().data.numpy(), axis=-1)[0])

    # Single-image batch: take its map and scale it back to the input size.
    grayscale_cam = grayscale_cam[0, :]
    grayscale_cam = cv2.resize(grayscale_cam, (image_width, image_height), interpolation=cv2.INTER_CUBIC)

    pixels = np.float32(image)
    if BWHighlight:
        # Dim every pixel by its CAM weight so only the relevant areas stay bright.
        visualization = (pixels * grayscale_cam[..., np.newaxis]).astype(np.uint8)
    else:
        visualization = show_cam_on_image(pixels / 255, grayscale_cam, use_rgb=True, image_weight=alpha, colormap=colormap)

    if not label_image:
        return Image.fromarray(visualization)

    # Append an opaque alpha channel so the label can be drawn on an RGBA image.
    opaque = np.full((image_height, image_width, 1), 255, dtype=np.uint8)
    labelled = Image.fromarray(np.concatenate([visualization, opaque], axis=-1))
    draw = ImageDraw.Draw(labelled)
    draw.rectangle((5, 5, 150, 30), fill=(10, 10, 10, 100))
    animal = predicted_animal.capitalize()
    if target_lang is not None and target_lang != "None":
        animal += f' ({get_translated(animal, target_lang)})'
    draw.text((10, 7), animal, font=font, fill=(255, 125, 0, 255))
    return labelled
def gradcam_video(video, colormap="Jet", use_eigen_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class", label_image=True, target_lang="German"):
    """Render a class activation map for every (sub-sampled) frame of a video.

    Args:
        video: Path of the video file to process.
        colormap: Key of ``CV2_COLORMAPS`` used for the heat-map overlay.
        use_eigen_smooth: Forwarded to the CAM call as ``eigen_smooth``.
        BWHighlight: If true, multiply each frame by its CAM map instead of
            overlaying a colored heat map.
        alpha: Passed to ``show_cam_on_image`` as ``image_weight``.
        cam_method: Key of ``CAM_METHODS`` (what the UI sends) or a CAM class
            (the signature default).
        layer: Key of ``LAYERS``; ``None`` selects ``"layer4"``.
        specific_class: Class name to explain, or ``"Predicted Class"``.
        label_image: If true, write the predicted class name onto each frame.
        target_lang: Language of the translation appended to the label;
            ``None`` or ``"None"`` disables it.

    Returns:
        The path of the written video, ``'src/results/gradcam_video.mp4'``.

    Raises:
        gr.Error: on a missing video, an unknown colormap / CAM method / layer,
            a video larger than 2000x2000, an empty video, or a video with
            ``MAXIMAL_FRAMES`` or more frames.
    """
    if video is None:
        raise gr.Error("Please upload a video.")
    if colormap not in CV2_COLORMAPS:
        raise gr.Error(f"Colormap {colormap} not found in {list(CV2_COLORMAPS.keys())}.")
    colormap = CV2_COLORMAPS[colormap]

    # The UI passes the method name, while the signature default is the class
    # itself; accept both instead of failing on the dictionary lookup.
    if isinstance(cam_method, str):
        if cam_method not in CAM_METHODS:
            raise gr.Error(f"CAM method {cam_method} not found in {list(CAM_METHODS.keys())}.")
        cam_class = CAM_METHODS[cam_method]
    else:
        cam_class = cam_method

    layer_name = "layer4" if layer is None else layer
    if layer_name not in LAYERS:
        raise gr.Error(f"Layer {layer_name} not found in {list(LAYERS.keys())}.")
    layers = LAYERS[layer_name]

    capture = cv2.VideoCapture(video)
    try:
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        if width > 2000 or height > 2000:
            raise gr.Error("The video is too big. The maximal size is 2000x2000.")
        print(f'FPS: {fps}, Width: {width}, Height: {height}')

        frames = []
        success, frame = capture.read()
        while success:
            frames.append(frame)
            # Stop as soon as the limit is hit instead of first decoding an
            # arbitrarily long video into memory.
            if len(frames) >= MAXIMAL_FRAMES:
                raise gr.Error(f"The video is too long. The maximal length is {MAXIMAL_FRAMES} frames.")
            success, frame = capture.read()
    finally:
        # Release the capture on every path, including the errors above.
        capture.release()

    print(f'Frames: {len(frames)}')
    if not frames:
        raise gr.Error("The video is empty.")

    # Keep the module constant untouched: -1 means "use the source frame rate"
    # and has to be resolved per video, not once for the whole process.
    output_fps = fps if OUTPUT_FPS == -1 else OUTPUT_FPS

    if len(frames) > MAX_OUT_FRAMES:
        # Ceiling division, so that at most MAX_OUT_FRAMES frames remain.
        step = -(-len(frames) // MAX_OUT_FRAMES)
        frames = frames[::step]
    print(f'Frames to process: {len(frames)}')

    MODEL.eval()
    targets = [ClassifierOutputTarget(get_class_idx(specific_class))] if specific_class != "Predicted Class" else None
    wants_translation = target_lang is not None and target_lang != "None"
    label_texts = {}  # predicted class name -> label text, built once per class
    results = []

    with cam_class(model=MODEL, target_layers=layers) as cam:
        for batch_start in tqdm(range(0, len(frames), BATCHES_TO_PROCESS)):
            # OpenCV decodes to BGR; the model and the overlay work on RGB.
            batch = [
                Image.fromarray(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
                for bgr_frame in frames[batch_start:batch_start + BATCHES_TO_PROCESS]
            ]
            images_tensor = torch.stack([transform(image) for image in batch])
            grayscale_cam = cam(input_tensor=images_tensor, targets=targets, aug_smooth=False, eigen_smooth=use_eigen_smooth)
            if label_image:
                # NOTE(review): relies on MODEL exposing the logits of its last
                # forward pass as ``MODEL.output`` -- confirm in CustomResNet18.
                predicted = np.argmax(MODEL.output.cpu().data.numpy(), axis=-1)

            for index, image in enumerate(batch):
                frame_cam = cv2.resize(grayscale_cam[index, :], (width, height), interpolation=cv2.INTER_LINEAR)
                pixels = np.float32(image)
                if BWHighlight:
                    visualization = (pixels * frame_cam[..., np.newaxis]).astype(np.uint8)
                else:
                    visualization = show_cam_on_image(pixels / 255, frame_cam, use_rgb=True, image_weight=alpha, colormap=colormap)

                if label_image:
                    predicted_animal = get_class_name(predicted[index])
                    if predicted_animal not in label_texts:
                        animal = predicted_animal.capitalize()
                        if wants_translation:
                            animal += f' ({get_translated(animal, target_lang)})'
                        label_texts[predicted_animal] = animal
                    labelled = Image.fromarray(visualization)
                    draw = ImageDraw.Draw(labelled)
                    draw.rectangle((5, 5, 150, 30), fill=(10, 10, 10, 100))
                    draw.text((10, 7), label_texts[predicted_animal], font=font, fill=(255, 125, 0, 255))
                    visualization = np.array(labelled)
                results.append(visualization)

    # save video
    mediapy.write_video('src/results/gradcam_video.mp4', results, fps=output_fps)
    return 'src/results/gradcam_video.mp4'
def load_examples():
    """Build the example-image gallery inside the currently open Gradio context.

    For each example folder below ``IMAGE_PATH`` a header, a description and
    rows of at most ``IMAGES_PER_ROW`` non-interactive images are created. Each
    image is scaled so that its longest side is 600px.

    Returns:
        A ``defaultdict(list)`` mapping the folder name to the ``gr.Image``
        components created for it.
    """
    folder_name_to_header = {
        "AI_Generated": "AI Generated Images",
        "true": "True Predicted Images (Validation Set)",
        "false": "False Predicted Images (Validation Set)",
        "others": "Other interesting images from the internet"
    }
    images_description = {
        "AI_Generated": "These images are generated by Dalle3 and Stable Diffusion. All of them are not real images and because of that it is interesting to see how the model predicts them.",
        "true": "These images are from the validation set and the model predicted them correctly.",
        "false": "These images are from the validation set and the model predicted them incorrectly. Maybe you can see why the model predicted them incorrectly using the GradCAM visualization. :)",
        "others": "These images are from the internet and are not part of the validation set. They are interesting because most of them show different animals."
    }
    loaded_images = defaultdict(list)
    # Dict order is the display order: AI_Generated, true, false, others.
    for image_type in folder_name_to_header:
        full_path = os.path.join(IMAGE_PATH, image_type).replace('\\', '/').replace('//', '/')
        gr.Markdown(f'## {folder_name_to_header[image_type]}')
        gr.Markdown(images_description[image_type])
        # os.listdir returns entries in arbitrary order; sort for a stable gallery.
        images_to_load = sorted(os.listdir(full_path))
        # Stepping by the row size creates exactly as many rows as needed (no
        # trailing empty row when the count is a multiple of IMAGES_PER_ROW).
        for row_start in range(0, len(images_to_load), IMAGES_PER_ROW):
            with gr.Row(elem_classes=["row-example-images"], equal_height=False):
                for file_name in images_to_load[row_start:row_start + IMAGES_PER_ROW]:
                    # splitext keeps dots inside the name ("a.b.png" -> "a.b").
                    name = os.path.splitext(file_name)[0]
                    # Close the file handle once the resized copy exists.
                    with Image.open(os.path.join(full_path, file_name)) as original:
                        # scale so that the longest side is 600px
                        scale = 600 / max(original.size)
                        image = original.resize((
                            max(1, int(original.size[0] * scale)),
                            max(1, int(original.size[1] * scale)),
                        ))
                    loaded_images[image_type].append(
                        gr.Image(
                            value=image,
                            label=name,
                            type="pil",
                            interactive=False,
                            elem_classes=["selectable_images"],
                        )
                    )
    return loaded_images
# Custom CSS handed to gr.Blocks: right-aligns the logo, justifies paragraph
# text and shows a pointer cursor on selectable elements.
css = "\n".join((
    "",
    "#logo {text-align: right;}",
    "p {text-align: justify; text-justify: inter-word; font-size: 1.1em; line-height: 1.2em;}",
    ".svelte-1btp92j.selectable {cursor: pointer !important; }",
    "",
))
# Gradio UI definition. Builds `demo` with a header row (markdown + logo), a
# translation-language dropdown, the image upload, and four tabs: "Predict",
# "Explain Image", "Explain Video" and "Example Images".
# NOTE(review): the nesting of the `with` blocks below cannot be recovered from
# this flattened copy of the file, so the code lines are left exactly as found.
| with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo: | |
| # ------------------------------------------- | |
| # HEADER WITH LOGO | |
| # ------------------------------------------- | |
| with gr.Row(): | |
| with open('src/header.md', 'r', encoding='utf-8') as f: | |
| markdown_string = f.read() | |
| with gr.Column(scale=10): | |
| header = gr.Markdown(markdown_string) | |
| with gr.Column(scale=1): | |
| pil_logo = Image.open('animals.png') | |
| logo = gr.Image(value=pil_logo, scale=2, interactive=False, show_download_button=False, show_label=False, container=False, elem_id="logo") | |
# Language shared by all tabs; it is passed to infer_image, gradcam and the
# video tab builder as the translation target.
| animal_translation_target_language = gr.Dropdown( | |
| choices=LANGUAGES_TO_SELECT.keys(), | |
| label="Translation language for animals", | |
| value="German", | |
| interactive=True, | |
| scale=2, | |
| ) | |
| # ------------------------------------------- | |
| # INPUT IMAGE | |
| # ------------------------------------------- | |
| with gr.Row(): | |
| with gr.Row(variant="panel", equal_height=True): | |
| user_image = gr.Image( | |
| type="pil", | |
| label="Upload Your Own Image", | |
| interactive=True, | |
| ) | |
| # ------------------------------------------- | |
| # TOOLS | |
| # ------------------------------------------- | |
| with gr.Row(): | |
| # ------------------------------------------- | |
| # PREDICT | |
| # ------------------------------------------- | |
| with gr.Tab("Predict"): | |
| with gr.Column(): | |
# NOTE(review): the info text below says "Top three" while num_top_classes=5.
| output = gr.Label( | |
| num_top_classes=5, | |
| label="Output", | |
| info="Top three predicted classes and their confidences.", | |
| scale=5, | |
| ) | |
| with gr.Row(): | |
| predict_mode_button = gr.Button(value="Predict Animal", label="Predict", info="Click to make a prediction.", scale=6) | |
| predict_mode_button.click(fn=infer_image, inputs=[user_image, animal_translation_target_language], outputs=output, queue=True) | |
| # ------------------------------------------- | |
| # EXPLAIN | |
| # ------------------------------------------- | |
| with gr.Tab("Explain Image"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| _info = "There are different GradCAM methods. You can read more about them here: (https://github.com/jacobgil/pytorch-grad-cam#references)." | |
| cam_method = gr.Radio( | |
| list(CAM_METHODS.keys()), | |
| label="GradCAM Method", | |
| info=_info, | |
| value="GradCAM", | |
| interactive=True, | |
| scale=2, | |
| ) | |
| _info = """ | |
| The alpha value is used to blend the original image with the GradCAM visualization. If you choose a value of 0.5 the original image and the GradCAM visualization will be blended equally. | |
| If you choose a value of 0.1 the original image will be barely visible and if you choose a value of 0.9 the GradCAM visualization will be barely visible. | |
| """ | |
| alpha = gr.Slider( | |
| minimum=.1, | |
| maximum=.9, | |
| value=0.5, | |
| interactive=True, | |
| step=.1, | |
| label="Alpha", | |
| scale=1, | |
| info=_info | |
| ) | |
# NOTE(review): the text below mentions "ResNet50", but MODEL is created as
# CustomResNet18 -- confirm which name should be shown to the user.
| _info = """ | |
| The layer is used to choose the layer of the ResNet50 model. The GradCAM visualization will be based on this layer. | |
| Best to choose is the last layer (layer4) because it is the layer with the most information before the final prediction. This makes the GradCAM visualization the most meaningful. | |
| If all layers are chosen the GradCAM visualization will be averaged over all layers. | |
| """ | |
| layer = gr.Radio( | |
| LAYERS.keys(), | |
| label="Layer", | |
| value="layer4", | |
| interactive=True, | |
| scale=2, | |
| info=_info | |
| ) | |
| with gr.Row(): | |
| _info = """ | |
| Here you can choose the animal to "explain". If you choose "Predicted Class" the GradCAM visualization will be based on the predicted class. | |
| If you choose a specific class the GradCAM visualization will be based on this class. | |
| For example if you have an image with a dog and a cat, you can select either Cat or Dog and see if the model can focus on the correct animal. | |
| """ | |
| animal_to_explain = gr.Dropdown( | |
| choices=["Predicted Class"] + ALL_CLASSES, | |
| label="Animal", | |
| value="Predicted Class", | |
| interactive=True, | |
| scale=4, | |
| info=_info | |
| ) | |
| show_predicted_class = gr.Checkbox( | |
| label="Show Predicted Class", | |
| value=True, | |
| interactive=True, | |
| scale=1, | |
| ) | |
| with gr.Row(): | |
| _info = """ | |
| Here you can choose the colormap. Instead of a colormap you can also choose "BW Highlight" to just keep the original image and highlight the important parts of the image. | |
| If you select "BW Highlight" the colormap will be ignored. | |
| """ | |
| colormap = gr.Dropdown( | |
| choices=list(CV2_COLORMAPS.keys()), | |
| label="Colormap", | |
| value="Inferno", | |
| interactive=True, | |
| scale=2, | |
| info=_info | |
| ) | |
| bw_highlight = gr.Checkbox( | |
| label="BW Highlight", | |
| value=False, | |
| interactive=True, | |
| scale=1, | |
| ) | |
# NOTE(review): this only sets a Python attribute on the component; whether
# Gradio displays it anywhere is not visible from this file.
| bw_highlight.description = "Here you can choose if you want to highlight the important parts of the image in black and white." | |
| with gr.Row(): | |
| _info = """ | |
| The Eigen Smooth is a method to smooth the GradCAM visualization. | |
| """ | |
| use_eigen_smooth = gr.Checkbox( | |
| label="Eigen Smooth", | |
| value=False, | |
| interactive=True, | |
| scale=1, | |
| info=_info | |
| ) | |
| _info = """ | |
| The Aug Smooth is also a method to smooth the GradCAM visualization. But this method needs a lot of performance and is therefore slow. | |
| """ | |
| use_aug_smooth = gr.Checkbox( | |
| label="Aug Smooth", | |
| value=False, | |
| interactive=True, | |
| scale=1, | |
| info=_info | |
| ) | |
| with gr.Column(): | |
| gradcam_mode_button = gr.Button(value="Show GradCAM", label="GradCAM", info="Click to make a prediction.", scale=1) | |
| output_cam = gr.Image( | |
| type="pil", | |
| label="GradCAM", | |
| info="GradCAM visualization", | |
| show_label=False, | |
| scale=7, | |
| ) | |
# The component order must match the positional parameter order of gradcam():
# image, colormap, use_eigen_smooth, use_aug_smooth, BWHighlight, alpha,
# cam_method, layer, specific_class, label_image, target_lang.
| _inputs = [user_image, colormap, use_eigen_smooth, use_aug_smooth, bw_highlight, alpha, cam_method, layer, animal_to_explain, show_predicted_class, animal_translation_target_language] | |
| gradcam_mode_button.click(fn=gradcam, inputs=_inputs, outputs=output_cam, queue=True) | |
| # ------------------------------------------- | |
| # Video CAM | |
| # ------------------------------------------- | |
| with gr.Tab("Explain Video"): | |
| build_video_to_camvideo(CAM_METHODS, CV2_COLORMAPS, LAYERS, ALL_CLASSES, gradcam_video, animal_translation_target_language) | |
| # ------------------------------------------- | |
| # EXAMPLES | |
| # ------------------------------------------- | |
| with gr.Tab("Example Images"): | |
| placeholder = gr.Markdown("## Example Images") | |
| loaded_images = load_examples() | |
# Selecting an example image copies it into the upload component (user_image).
| for k in loaded_images.keys(): | |
| for image in loaded_images[k]: | |
| image.select(fn=lambda x: x, inputs=[image], outputs=[user_image], queue=True, scroll_to_output=True) | |
if __name__ == "__main__":
    # Script entry point: enable request queuing, announce the start on
    # stdout, then start the Gradio server.
    launch_options = {"show_tips": True}
    demo.queue()
    print("Starting Gradio server...")
    demo.launch(**launch_options)