Spaces:
Sleeping
Sleeping
Add new classes and features
Browse files- app.py +109 -57
- src/Nets.py +1 -38
- src/Roboto-Regular.ttf +0 -0
- src/cache/val_df.csv +0 -0
- src/examples/{false_predicted/squirrel.jpg → false/bee.jpg} +2 -2
- src/examples/{false_predicted/chimpanzee.jpg → false/coyote.jpg} +2 -2
- src/examples/{true_predicted/cat.jpg → false/donkey.jpg} +2 -2
- src/examples/false/goat.jpg +3 -0
- src/examples/false/hornbill.jpg +3 -0
- src/examples/false_predicted/starfish.jpg +0 -3
- src/examples/true/dolphin.jpg +3 -0
- src/examples/true/dragonfly.jpg +3 -0
- src/examples/{false_predicted → true}/koala.jpg +2 -2
- src/examples/{false_predicted → true}/sheep.jpg +2 -2
- src/examples/true/squid.jpg +3 -0
- src/examples/true_predicted/cockroach.jpg +0 -3
- src/examples/true_predicted/flamingo.jpg +0 -3
- src/examples/true_predicted/gorilla.jpg +0 -3
- src/examples/true_predicted/grasshopper.jpg +0 -3
- src/gradio_blocks.py +2 -2
- src/header.md +2 -2
- src/results/gradcam_video.mp4 +2 -2
- src/results/infer_image.png +2 -2
- src/results/models/best_model.pth +2 -2
- src/util.py +1 -75
app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import copy
|
| 2 |
import os
|
| 3 |
import sys
|
|
@@ -16,31 +17,43 @@ import torch
|
|
| 16 |
from deep_translator import GoogleTranslator
|
| 17 |
from gradio_blocks import build_video_to_camvideo
|
| 18 |
from Nets import CustomResNet18
|
| 19 |
-
from PIL import Image
|
| 20 |
|
| 21 |
from pytorch_grad_cam import GradCAM, HiResCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
|
| 22 |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
| 23 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
| 24 |
|
| 25 |
from tqdm import tqdm
|
| 26 |
-
import
|
| 27 |
-
|
|
|
|
| 28 |
|
| 29 |
-
util.ImageCache = CustomImageCache(60, False)
|
| 30 |
ffmpeg_path = shutil.which('ffmpeg')
|
| 31 |
mediapy.set_ffmpeg(ffmpeg_path)
|
| 32 |
|
| 33 |
IMAGE_PATH = os.path.join(os.getcwd(), 'src/examples')
|
| 34 |
IMAGES_PER_ROW = 5
|
| 35 |
|
| 36 |
-
MAXIMAL_FRAMES =
|
| 37 |
-
BATCHES_TO_PROCESS =
|
| 38 |
OUTPUT_FPS = 10
|
| 39 |
-
MAX_OUT_FRAMES =
|
| 40 |
|
| 41 |
-
MODEL = CustomResNet18(
|
| 42 |
MODEL.load_state_dict(torch.load('src/results/models/best_model.pth', map_location=torch.device('cpu')))
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
CAM_METHODS = {
|
| 45 |
"GradCAM": GradCAM,
|
| 46 |
"GradCAM++": GradCAMPlusPlus,
|
|
@@ -87,16 +100,21 @@ def get_class_name(idx):
|
|
| 87 |
def get_class_idx(name):
|
| 88 |
return C_NAME_TO_NUM[name]
|
| 89 |
|
| 90 |
-
@lru_cache(maxsize=
|
| 91 |
-
def get_translated(to_translate):
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
if isinstance(image, dict):
|
| 98 |
-
# Its the image and a mask as pillow both -> Combine them to one image
|
| 99 |
-
image = Image.blend(image["image"], image["mask"], alpha=0.5)
|
| 100 |
image.save('src/results/infer_image.png')
|
| 101 |
image = transform(image)
|
| 102 |
image = image.unsqueeze(0)
|
|
@@ -105,11 +123,13 @@ def infer_image(image):
|
|
| 105 |
distribution = torch.nn.functional.softmax(output, dim=1)
|
| 106 |
ret = defaultdict(float)
|
| 107 |
for idx, prob in enumerate(distribution[0]):
|
| 108 |
-
animal = f'{get_class_name(idx)}
|
|
|
|
|
|
|
| 109 |
ret[animal] = prob.item()
|
| 110 |
return ret
|
| 111 |
|
| 112 |
-
def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
|
| 113 |
if image is None:
|
| 114 |
raise gr.Error("Please upload an image.")
|
| 115 |
|
|
@@ -123,8 +143,8 @@ def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False,
|
|
| 123 |
colormap = CV2_COLORMAPS[colormap]
|
| 124 |
|
| 125 |
image_width, image_height = image.size
|
| 126 |
-
if image_width >
|
| 127 |
-
raise gr.Error("The image is too big. The maximal size is
|
| 128 |
|
| 129 |
|
| 130 |
MODEL.eval()
|
|
@@ -135,6 +155,8 @@ def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False,
|
|
| 135 |
|
| 136 |
with CAM_METHODS[cam_method](model=MODEL, target_layers=layers) as cam:
|
| 137 |
grayscale_cam = cam(input_tensor=image_tensor, targets=targets, aug_smooth=use_aug_smooth, eigen_smooth=use_eigen_smooth)
|
|
|
|
|
|
|
| 138 |
|
| 139 |
grayscale_cam = grayscale_cam[0, :]
|
| 140 |
grayscale_cam = cv2.resize(grayscale_cam, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
|
|
@@ -146,10 +168,25 @@ def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False,
|
|
| 146 |
else:
|
| 147 |
image = image / 255
|
| 148 |
visualization = show_cam_on_image(image, grayscale_cam, use_rgb=True, image_weight=alpha, colormap=colormap)
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
def gradcam_video(video, colormap="Jet", use_eigen_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
|
| 152 |
global OUTPUT_FPS, MAXIMAL_FRAMES, BATCHES_TO_PROCESS, MAX_OUT_FRAMES
|
|
|
|
| 153 |
if colormap not in CV2_COLORMAPS.keys():
|
| 154 |
raise gr.Error(f"Colormap {colormap} not found in {list(CV2_COLORMAPS.keys())}.")
|
| 155 |
else:
|
|
@@ -159,8 +196,8 @@ def gradcam_video(video, colormap="Jet", use_eigen_smooth=False, BWHighlight=Fal
|
|
| 159 |
if OUTPUT_FPS == -1: OUTPUT_FPS = fps
|
| 160 |
width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 161 |
height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 162 |
-
if width >
|
| 163 |
-
raise gr.Error("The video is too big. The maximal size is
|
| 164 |
print(f'FPS: {fps}, Width: {width}, Height: {height}')
|
| 165 |
|
| 166 |
frames = list()
|
|
@@ -213,21 +250,21 @@ def gradcam_video(video, colormap="Jet", use_eigen_smooth=False, BWHighlight=Fal
|
|
| 213 |
def load_examples():
|
| 214 |
folder_name_to_header = {
|
| 215 |
"AI_Generated": "AI Generated Images",
|
| 216 |
-
"
|
| 217 |
-
"
|
| 218 |
"others": "Other interesting images from the internet"
|
| 219 |
}
|
| 220 |
|
| 221 |
images_description = {
|
| 222 |
"AI_Generated": "These images are generated by Dalle3 and Stable Diffusion. All of them are not real images and because of that it is interesting to see how the model predicts them.",
|
| 223 |
-
"
|
| 224 |
-
"
|
| 225 |
"others": "These images are from the internet and are not part of the validation set. They are interesting because most of them show different animals."
|
| 226 |
}
|
| 227 |
|
| 228 |
loaded_images = defaultdict(list)
|
| 229 |
|
| 230 |
-
for image_type in ["AI_Generated", "
|
| 231 |
# for image_type in os.listdir(IMAGE_PATH):
|
| 232 |
full_path = os.path.join(IMAGE_PATH, image_type).replace('\\', '/').replace('//', '/')
|
| 233 |
gr.Markdown(f'## {folder_name_to_header[image_type]}')
|
|
@@ -239,7 +276,7 @@ def load_examples():
|
|
| 239 |
for j in range(IMAGES_PER_ROW):
|
| 240 |
if i * IMAGES_PER_ROW + j >= len(images_to_load): break
|
| 241 |
image = images_to_load[i * IMAGES_PER_ROW + j]
|
| 242 |
-
name = f"{image.split('.')[0]}
|
| 243 |
image = Image.open(os.path.join(full_path, image))
|
| 244 |
# scale so that the longest side is 600px
|
| 245 |
scale = 600 / max(image.size)
|
|
@@ -273,7 +310,15 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
|
|
| 273 |
with gr.Column(scale=1):
|
| 274 |
pil_logo = Image.open('animals.png')
|
| 275 |
logo = gr.Image(value=pil_logo, scale=2, interactive=False, show_download_button=False, show_label=False, container=False, elem_id="logo")
|
| 276 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
# -------------------------------------------
|
| 278 |
# INPUT IMAGE
|
| 279 |
# -------------------------------------------
|
|
@@ -282,7 +327,6 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
|
|
| 282 |
user_image = gr.Image(
|
| 283 |
type="pil",
|
| 284 |
label="Upload Your Own Image",
|
| 285 |
-
tool="sketch",
|
| 286 |
interactive=True,
|
| 287 |
)
|
| 288 |
|
|
@@ -301,8 +345,9 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
|
|
| 301 |
info="Top three predicted classes and their confidences.",
|
| 302 |
scale=5,
|
| 303 |
)
|
| 304 |
-
|
| 305 |
-
|
|
|
|
| 306 |
|
| 307 |
# -------------------------------------------
|
| 308 |
# EXPLAIN
|
|
@@ -348,20 +393,28 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
|
|
| 348 |
scale=2,
|
| 349 |
info=_info
|
| 350 |
)
|
|
|
|
| 351 |
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
|
| 366 |
with gr.Row():
|
| 367 |
_info = """
|
|
@@ -371,7 +424,7 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
|
|
| 371 |
colormap = gr.Dropdown(
|
| 372 |
choices=list(CV2_COLORMAPS.keys()),
|
| 373 |
label="Colormap",
|
| 374 |
-
value="
|
| 375 |
interactive=True,
|
| 376 |
scale=2,
|
| 377 |
info=_info
|
|
@@ -410,15 +463,16 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
|
|
| 410 |
|
| 411 |
|
| 412 |
with gr.Column():
|
|
|
|
| 413 |
output_cam = gr.Image(
|
| 414 |
type="pil",
|
| 415 |
label="GradCAM",
|
| 416 |
info="GradCAM visualization",
|
| 417 |
-
|
|
|
|
| 418 |
)
|
| 419 |
-
|
| 420 |
-
gradcam_mode_button
|
| 421 |
-
gradcam_mode_button.click(fn=gradcam, inputs=[user_image, colormap, use_eigen_smooth, use_aug_smooth, bw_highlight, alpha, cam_method, layer, animal_to_explain], outputs=output_cam, queue=True)
|
| 422 |
|
| 423 |
# -------------------------------------------
|
| 424 |
# Video CAM
|
|
@@ -434,11 +488,9 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
|
|
| 434 |
loaded_images = load_examples()
|
| 435 |
for k in loaded_images.keys():
|
| 436 |
for image in loaded_images[k]:
|
| 437 |
-
image.select(fn=lambda x: x, inputs=[image], outputs=[user_image])
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
|
| 442 |
if __name__ == "__main__":
|
| 443 |
demo.queue()
|
| 444 |
-
|
|
|
|
|
|
| 1 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 2 |
import copy
|
| 3 |
import os
|
| 4 |
import sys
|
|
|
|
| 17 |
from deep_translator import GoogleTranslator
|
| 18 |
from gradio_blocks import build_video_to_camvideo
|
| 19 |
from Nets import CustomResNet18
|
| 20 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 21 |
|
| 22 |
from pytorch_grad_cam import GradCAM, HiResCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
|
| 23 |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
| 24 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
| 25 |
|
| 26 |
from tqdm import tqdm
|
| 27 |
+
from util import transform
|
| 28 |
+
|
| 29 |
+
font = ImageFont.truetype("src/Roboto-Regular.ttf", 16)
|
| 30 |
|
|
|
|
| 31 |
ffmpeg_path = shutil.which('ffmpeg')
|
| 32 |
mediapy.set_ffmpeg(ffmpeg_path)
|
| 33 |
|
| 34 |
IMAGE_PATH = os.path.join(os.getcwd(), 'src/examples')
|
| 35 |
IMAGES_PER_ROW = 5
|
| 36 |
|
| 37 |
+
MAXIMAL_FRAMES = 700
|
| 38 |
+
BATCHES_TO_PROCESS = 20
|
| 39 |
OUTPUT_FPS = 10
|
| 40 |
+
MAX_OUT_FRAMES = 70
|
| 41 |
|
| 42 |
+
MODEL = CustomResNet18(111).eval()
|
| 43 |
MODEL.load_state_dict(torch.load('src/results/models/best_model.pth', map_location=torch.device('cpu')))
|
| 44 |
|
| 45 |
+
LANGUAGES_TO_SELECT = {
|
| 46 |
+
"None": None,
|
| 47 |
+
"German": "de",
|
| 48 |
+
"French": "fr",
|
| 49 |
+
"Spanish": "es",
|
| 50 |
+
"Italian": "it",
|
| 51 |
+
"Finnish": "fi",
|
| 52 |
+
"Ukrainian": "uk",
|
| 53 |
+
"Japanese": "ja",
|
| 54 |
+
"Hebrew": "iw"
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
CAM_METHODS = {
|
| 58 |
"GradCAM": GradCAM,
|
| 59 |
"GradCAM++": GradCAMPlusPlus,
|
|
|
|
| 100 |
def get_class_idx(name):
|
| 101 |
return C_NAME_TO_NUM[name]
|
| 102 |
|
| 103 |
+
@lru_cache(maxsize=len(LANGUAGES_TO_SELECT.keys())*111)
|
| 104 |
+
def get_translated(to_translate, target_language="German"):
|
| 105 |
+
target_language = LANGUAGES_TO_SELECT[target_language] if target_language in LANGUAGES_TO_SELECT else target_language
|
| 106 |
+
if target_language == "en": return to_translate
|
| 107 |
+
if target_language not in LANGUAGES_TO_SELECT.values(): raise gr.Error(f'Language {target_language} not found.')
|
| 108 |
+
return GoogleTranslator(source="en", target=target_language).translate(to_translate)
|
| 109 |
+
# for idx in range(111): get_translated(get_class_name(idx))
|
| 110 |
+
with ThreadPoolExecutor(max_workers=30) as executor:
|
| 111 |
+
# give the executor the list of class names and the matching args (in this case, the target language)
|
| 112 |
+
# and let the executor map the translation function over the list of class names
|
| 113 |
+
for language in tqdm(LANGUAGES_TO_SELECT.keys(), desc='Preloading translations'):
|
| 114 |
+
executor.map(get_translated, ALL_CLASSES, [language] * len(ALL_CLASSES))
|
| 115 |
|
| 116 |
+
def infer_image(image, target_language):
|
| 117 |
+
if image is None: raise gr.Error("Please upload an image.")
|
|
|
|
|
|
|
|
|
|
| 118 |
image.save('src/results/infer_image.png')
|
| 119 |
image = transform(image)
|
| 120 |
image = image.unsqueeze(0)
|
|
|
|
| 123 |
distribution = torch.nn.functional.softmax(output, dim=1)
|
| 124 |
ret = defaultdict(float)
|
| 125 |
for idx, prob in enumerate(distribution[0]):
|
| 126 |
+
animal = f'{get_class_name(idx)}'
|
| 127 |
+
if target_language is not None and target_language != "None":
|
| 128 |
+
animal += f' ({get_translated(get_class_name(idx), target_language)})'
|
| 129 |
ret[animal] = prob.item()
|
| 130 |
return ret
|
| 131 |
|
| 132 |
+
def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class", label_image=True, target_lang="German"):
|
| 133 |
if image is None:
|
| 134 |
raise gr.Error("Please upload an image.")
|
| 135 |
|
|
|
|
| 143 |
colormap = CV2_COLORMAPS[colormap]
|
| 144 |
|
| 145 |
image_width, image_height = image.size
|
| 146 |
+
if image_width > 6000 or image_height > 6000:
|
| 147 |
+
raise gr.Error("The image is too big. The maximal size is 6000x6000.")
|
| 148 |
|
| 149 |
|
| 150 |
MODEL.eval()
|
|
|
|
| 155 |
|
| 156 |
with CAM_METHODS[cam_method](model=MODEL, target_layers=layers) as cam:
|
| 157 |
grayscale_cam = cam(input_tensor=image_tensor, targets=targets, aug_smooth=use_aug_smooth, eigen_smooth=use_eigen_smooth)
|
| 158 |
+
if label_image:
|
| 159 |
+
predicted_animal = get_class_name(np.argmax(cam.outputs.cpu().data.numpy(), axis=-1)[0])
|
| 160 |
|
| 161 |
grayscale_cam = grayscale_cam[0, :]
|
| 162 |
grayscale_cam = cv2.resize(grayscale_cam, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
|
|
|
|
| 168 |
else:
|
| 169 |
image = image / 255
|
| 170 |
visualization = show_cam_on_image(image, grayscale_cam, use_rgb=True, image_weight=alpha, colormap=colormap)
|
| 171 |
+
|
| 172 |
+
if label_image:
|
| 173 |
+
# add alpha channel to visualization
|
| 174 |
+
visualization = np.concatenate([visualization, np.ones((image_height, image_width, 1), dtype=np.uint8) * 255], axis=-1)
|
| 175 |
+
plt_image = Image.fromarray(visualization, mode="RGBA")
|
| 176 |
+
draw = ImageDraw.Draw(plt_image)
|
| 177 |
+
draw.rectangle((5, 5, 150, 30), fill=(10, 10, 10, 100))
|
| 178 |
+
animal = predicted_animal.capitalize()
|
| 179 |
+
if target_lang is not None and target_lang != "None":
|
| 180 |
+
animal += f' ({get_translated(animal, target_lang)})'
|
| 181 |
+
draw.text((10, 7), animal, font=font, fill=(255, 125, 0, 255))
|
| 182 |
+
visualization = np.array(plt_image)
|
| 183 |
+
|
| 184 |
+
out_image = Image.fromarray(visualization)
|
| 185 |
+
return out_image
|
| 186 |
|
| 187 |
def gradcam_video(video, colormap="Jet", use_eigen_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
|
| 188 |
global OUTPUT_FPS, MAXIMAL_FRAMES, BATCHES_TO_PROCESS, MAX_OUT_FRAMES
|
| 189 |
+
if video is None: raise gr.Error("Please upload a video.")
|
| 190 |
if colormap not in CV2_COLORMAPS.keys():
|
| 191 |
raise gr.Error(f"Colormap {colormap} not found in {list(CV2_COLORMAPS.keys())}.")
|
| 192 |
else:
|
|
|
|
| 196 |
if OUTPUT_FPS == -1: OUTPUT_FPS = fps
|
| 197 |
width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 198 |
height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 199 |
+
if width > 2000 or height > 2000:
|
| 200 |
+
raise gr.Error("The video is too big. The maximal size is 2000x2000.")
|
| 201 |
print(f'FPS: {fps}, Width: {width}, Height: {height}')
|
| 202 |
|
| 203 |
frames = list()
|
|
|
|
| 250 |
def load_examples():
|
| 251 |
folder_name_to_header = {
|
| 252 |
"AI_Generated": "AI Generated Images",
|
| 253 |
+
"true": "True Predicted Images (Validation Set)",
|
| 254 |
+
"false": "False Predicted Images (Validation Set)",
|
| 255 |
"others": "Other interesting images from the internet"
|
| 256 |
}
|
| 257 |
|
| 258 |
images_description = {
|
| 259 |
"AI_Generated": "These images are generated by Dalle3 and Stable Diffusion. All of them are not real images and because of that it is interesting to see how the model predicts them.",
|
| 260 |
+
"true": "These images are from the validation set and the model predicted them correctly.",
|
| 261 |
+
"false": "These images are from the validation set and the model predicted them incorrectly. Maybe you can see why the model predicted them incorrectly using the GradCAM visualization. :)",
|
| 262 |
"others": "These images are from the internet and are not part of the validation set. They are interesting because most of them show different animals."
|
| 263 |
}
|
| 264 |
|
| 265 |
loaded_images = defaultdict(list)
|
| 266 |
|
| 267 |
+
for image_type in ["AI_Generated", "true", "false", "others"]:
|
| 268 |
# for image_type in os.listdir(IMAGE_PATH):
|
| 269 |
full_path = os.path.join(IMAGE_PATH, image_type).replace('\\', '/').replace('//', '/')
|
| 270 |
gr.Markdown(f'## {folder_name_to_header[image_type]}')
|
|
|
|
| 276 |
for j in range(IMAGES_PER_ROW):
|
| 277 |
if i * IMAGES_PER_ROW + j >= len(images_to_load): break
|
| 278 |
image = images_to_load[i * IMAGES_PER_ROW + j]
|
| 279 |
+
name = f"{image.split('.')[0]}"
|
| 280 |
image = Image.open(os.path.join(full_path, image))
|
| 281 |
# scale so that the longest side is 600px
|
| 282 |
scale = 600 / max(image.size)
|
|
|
|
| 310 |
with gr.Column(scale=1):
|
| 311 |
pil_logo = Image.open('animals.png')
|
| 312 |
logo = gr.Image(value=pil_logo, scale=2, interactive=False, show_download_button=False, show_label=False, container=False, elem_id="logo")
|
| 313 |
+
|
| 314 |
+
animal_translation_target_language = gr.Dropdown(
|
| 315 |
+
choices=LANGUAGES_TO_SELECT.keys(),
|
| 316 |
+
label="Translation language for animals",
|
| 317 |
+
value="German",
|
| 318 |
+
interactive=True,
|
| 319 |
+
scale=2,
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
# -------------------------------------------
|
| 323 |
# INPUT IMAGE
|
| 324 |
# -------------------------------------------
|
|
|
|
| 327 |
user_image = gr.Image(
|
| 328 |
type="pil",
|
| 329 |
label="Upload Your Own Image",
|
|
|
|
| 330 |
interactive=True,
|
| 331 |
)
|
| 332 |
|
|
|
|
| 345 |
info="Top three predicted classes and their confidences.",
|
| 346 |
scale=5,
|
| 347 |
)
|
| 348 |
+
with gr.Row():
|
| 349 |
+
predict_mode_button = gr.Button(value="Predict Animal", label="Predict", info="Click to make a prediction.", scale=6)
|
| 350 |
+
predict_mode_button.click(fn=infer_image, inputs=[user_image, animal_translation_target_language], outputs=output, queue=True)
|
| 351 |
|
| 352 |
# -------------------------------------------
|
| 353 |
# EXPLAIN
|
|
|
|
| 393 |
scale=2,
|
| 394 |
info=_info
|
| 395 |
)
|
| 396 |
+
with gr.Row():
|
| 397 |
|
| 398 |
+
_info = """
|
| 399 |
+
Here you can choose the animal to "explain". If you choose "Predicted Class" the GradCAM visualization will be based on the predicted class.
|
| 400 |
+
If you choose a specific class the GradCAM visualization will be based on this class.
|
| 401 |
+
For example if you have an image with a dog and a cat, you can select either Cat or Dog and see if the model can focus on the correct animal.
|
| 402 |
+
"""
|
| 403 |
+
animal_to_explain = gr.Dropdown(
|
| 404 |
+
choices=["Predicted Class"] + ALL_CLASSES,
|
| 405 |
+
label="Animal",
|
| 406 |
+
value="Predicted Class",
|
| 407 |
+
interactive=True,
|
| 408 |
+
scale=4,
|
| 409 |
+
info=_info
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
show_predicted_class = gr.Checkbox(
|
| 413 |
+
label="Show Predicted Class",
|
| 414 |
+
value=True,
|
| 415 |
+
interactive=True,
|
| 416 |
+
scale=1,
|
| 417 |
+
)
|
| 418 |
|
| 419 |
with gr.Row():
|
| 420 |
_info = """
|
|
|
|
| 424 |
colormap = gr.Dropdown(
|
| 425 |
choices=list(CV2_COLORMAPS.keys()),
|
| 426 |
label="Colormap",
|
| 427 |
+
value="Inferno",
|
| 428 |
interactive=True,
|
| 429 |
scale=2,
|
| 430 |
info=_info
|
|
|
|
| 463 |
|
| 464 |
|
| 465 |
with gr.Column():
|
| 466 |
+
gradcam_mode_button = gr.Button(value="Show GradCAM", label="GradCAM", info="Click to make a prediction.", scale=1)
|
| 467 |
output_cam = gr.Image(
|
| 468 |
type="pil",
|
| 469 |
label="GradCAM",
|
| 470 |
info="GradCAM visualization",
|
| 471 |
+
show_label=False,
|
| 472 |
+
scale=7,
|
| 473 |
)
|
| 474 |
+
_inputs = [user_image, colormap, use_eigen_smooth, use_aug_smooth, bw_highlight, alpha, cam_method, layer, animal_to_explain, show_predicted_class, animal_translation_target_language]
|
| 475 |
+
gradcam_mode_button.click(fn=gradcam, inputs=_inputs, outputs=output_cam, queue=True)
|
|
|
|
| 476 |
|
| 477 |
# -------------------------------------------
|
| 478 |
# Video CAM
|
|
|
|
| 488 |
loaded_images = load_examples()
|
| 489 |
for k in loaded_images.keys():
|
| 490 |
for image in loaded_images[k]:
|
| 491 |
+
image.select(fn=lambda x: x, inputs=[image], outputs=[user_image], queue=True, scroll_to_output=True)
|
|
|
|
|
|
|
|
|
|
| 492 |
|
| 493 |
if __name__ == "__main__":
|
| 494 |
demo.queue()
|
| 495 |
+
print("Starting Gradio server...")
|
| 496 |
+
demo.launch(show_tips=True)
|
src/Nets.py
CHANGED
|
@@ -1,47 +1,10 @@
|
|
| 1 |
-
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
from torchvision import models
|
| 5 |
|
| 6 |
-
class SimpleCNN(nn.Module):
|
| 7 |
-
def __init__(self, k_size=3, pool_size=2, num_classes=1):
|
| 8 |
-
super(SimpleCNN, self).__init__()
|
| 9 |
-
self.relu = nn.ReLU()
|
| 10 |
-
# First Convolutional Layer
|
| 11 |
-
self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=k_size, padding=1)
|
| 12 |
-
self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=k_size, stride=1, padding=1)
|
| 13 |
-
self.pool1 = nn.MaxPool2d(kernel_size=pool_size)
|
| 14 |
-
|
| 15 |
-
# Second Convolutional Layer
|
| 16 |
-
self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=k_size, stride=1, padding=1)
|
| 17 |
-
self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=k_size, stride=1, padding=1)
|
| 18 |
-
self.pool2 = nn.MaxPool2d(kernel_size=pool_size)
|
| 19 |
-
|
| 20 |
-
self.conv5 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=k_size, stride=1, padding=1)
|
| 21 |
-
self.conv6 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=k_size, stride=1, padding=1)
|
| 22 |
-
self.pool3 = nn.MaxPool2d(kernel_size=pool_size)
|
| 23 |
-
|
| 24 |
-
self.conv7 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=k_size, stride=1, padding=1)
|
| 25 |
-
self.conv8 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=k_size, stride=1, padding=1)
|
| 26 |
-
self.pool4 = nn.MaxPool2d(kernel_size=pool_size)
|
| 27 |
-
|
| 28 |
-
# Fully Connected Layers
|
| 29 |
-
self.fc = nn.Linear(64*14*14, num_classes) # Adjust the input features based on your input image size
|
| 30 |
-
|
| 31 |
-
def forward(self, x):
|
| 32 |
-
x = self.pool1(self.relu(self.conv2(self.relu(self.conv1(x)))))
|
| 33 |
-
x = self.pool2(self.relu(self.conv4(self.relu(self.conv3(x)))))
|
| 34 |
-
x = self.pool3(self.relu(self.conv6(self.relu(self.conv5(x)))))
|
| 35 |
-
x = self.pool4(self.relu(self.conv8(self.relu(self.conv7(x)))))
|
| 36 |
-
# print(x.shape)
|
| 37 |
-
x = x.view(x.size(0), -1)
|
| 38 |
-
x = self.fc(x)
|
| 39 |
-
return x
|
| 40 |
-
|
| 41 |
class CustomResNet18(nn.Module):
|
| 42 |
def __init__(self, num_classes=11):
|
| 43 |
super(CustomResNet18, self).__init__()
|
| 44 |
-
self.resnet = models.
|
| 45 |
num_features = self.resnet.fc.in_features
|
| 46 |
self.resnet.fc = nn.Linear(num_features, num_classes)
|
| 47 |
|
|
|
|
|
|
|
| 1 |
import torch.nn as nn
|
|
|
|
| 2 |
from torchvision import models
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
class CustomResNet18(nn.Module):
|
| 5 |
def __init__(self, num_classes=11):
|
| 6 |
super(CustomResNet18, self).__init__()
|
| 7 |
+
self.resnet = models.resnet18(pretrained=True)
|
| 8 |
num_features = self.resnet.fc.in_features
|
| 9 |
self.resnet.fc = nn.Linear(num_features, num_classes)
|
| 10 |
|
src/Roboto-Regular.ttf
ADDED
|
Binary file (515 kB). View file
|
|
|
src/cache/val_df.csv
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/examples/{false_predicted/squirrel.jpg → false/bee.jpg}
RENAMED
|
File without changes
|
src/examples/{false_predicted/chimpanzee.jpg → false/coyote.jpg}
RENAMED
|
File without changes
|
src/examples/{true_predicted/cat.jpg → false/donkey.jpg}
RENAMED
|
File without changes
|
src/examples/false/goat.jpg
ADDED
|
Git LFS Details
|
src/examples/false/hornbill.jpg
ADDED
|
Git LFS Details
|
src/examples/false_predicted/starfish.jpg
DELETED
Git LFS Details
|
src/examples/true/dolphin.jpg
ADDED
|
Git LFS Details
|
src/examples/true/dragonfly.jpg
ADDED
|
Git LFS Details
|
src/examples/{false_predicted → true}/koala.jpg
RENAMED
|
File without changes
|
src/examples/{false_predicted → true}/sheep.jpg
RENAMED
|
File without changes
|
src/examples/true/squid.jpg
ADDED
|
Git LFS Details
|
src/examples/true_predicted/cockroach.jpg
DELETED
Git LFS Details
|
src/examples/true_predicted/flamingo.jpg
DELETED
Git LFS Details
|
src/examples/true_predicted/gorilla.jpg
DELETED
Git LFS Details
|
src/examples/true_predicted/grasshopper.jpg
DELETED
Git LFS Details
|
src/gradio_blocks.py
CHANGED
|
@@ -29,7 +29,7 @@ def build_video_to_camvideo(CAM_METHODS, CV2_COLORMAPS, LAYERS, ALL_CLASSES, gra
|
|
| 29 |
)
|
| 30 |
|
| 31 |
video_layer = gr.Radio(
|
| 32 |
-
|
| 33 |
label="Layer",
|
| 34 |
value="layer4",
|
| 35 |
interactive=True,
|
|
@@ -48,7 +48,7 @@ def build_video_to_camvideo(CAM_METHODS, CV2_COLORMAPS, LAYERS, ALL_CLASSES, gra
|
|
| 48 |
colormap = gr.Dropdown(
|
| 49 |
choices=list(CV2_COLORMAPS.keys()),
|
| 50 |
label="Colormap",
|
| 51 |
-
value="
|
| 52 |
interactive=True,
|
| 53 |
scale=2,
|
| 54 |
)
|
|
|
|
| 29 |
)
|
| 30 |
|
| 31 |
video_layer = gr.Radio(
|
| 32 |
+
[f"layer{i}" for i in range(1, 5)],
|
| 33 |
label="Layer",
|
| 34 |
value="layer4",
|
| 35 |
interactive=True,
|
|
|
|
| 48 |
colormap = gr.Dropdown(
|
| 49 |
choices=list(CV2_COLORMAPS.keys()),
|
| 50 |
label="Colormap",
|
| 51 |
+
value="Inferno",
|
| 52 |
interactive=True,
|
| 53 |
scale=2,
|
| 54 |
)
|
src/header.md
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
This project was created by [Ilyesse](https://github.com/ilyii) and [Gabriel](https://github.com/Gabriel9753) as part of the Explainable Machine Learning module at the [University of Applied Sciences Karlsruhe](https://www.h-ka.de/).
|
| 4 |
|
| 5 |
-
The dataset used in this project is the [Animal Image Dataset](https://www.kaggle.com/datasets/iamsouravbanerjee/animal-image-dataset-90-different-animals) from Kaggle, comprising 90 different animal species that needed to be classified. We also added approx. 1000 AI generated images for all classes to get a more diverse dataset and also improve the performance of the model.
|
| 6 |
|
| 7 |
-
The employed model is
|
| 8 |
Translation of animal names by [deep-translator](https://pypi.org/project/deep-translator/).
|
| 9 |
|
| 10 |
## Usage 🦎
|
|
|
|
| 2 |
|
| 3 |
This project was created by [Ilyesse](https://github.com/ilyii) and [Gabriel](https://github.com/Gabriel9753) as part of the Explainable Machine Learning module at the [University of Applied Sciences Karlsruhe](https://www.h-ka.de/).
|
| 4 |
|
| 5 |
+
The dataset used in this project is the [Animal Image Dataset](https://www.kaggle.com/datasets/iamsouravbanerjee/animal-image-dataset-90-different-animals) from Kaggle, comprising 90 different animal species to be classified. To cover a few more animals, we added 21 additional unique classes, so we ended up working with our own 111-animal dataset. We also added approx. 1000 AI-generated images across all classes to make the dataset more diverse and to improve the performance of the model.
|
| 6 |
|
| 7 |
+
The employed model is ResNet18, which was trained on the dataset using transfer learning techniques.
|
| 8 |
Translation of animal names by [deep-translator](https://pypi.org/project/deep-translator/).
|
| 9 |
|
| 10 |
## Usage 🦎
|
src/results/gradcam_video.mp4
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d88ec14ff35116bf5d8bd65454616aba242d8f79bde4dcbd717aabbcc910670a
|
| 3 |
+
size 917687
|
src/results/infer_image.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
src/results/models/best_model.pth
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e6a3f852efacebef8dee4ba74c0a73a7f33bf2180c4272dbf233a5c6157d7531
|
| 3 |
+
size 45015274
|
src/util.py
CHANGED
|
@@ -1,83 +1,9 @@
|
|
| 1 |
import torchvision.transforms as transforms
|
| 2 |
-
from torch.utils.data import DataLoader, Dataset
|
| 3 |
-
from sklearn.preprocessing import LabelEncoder
|
| 4 |
-
from tqdm import tqdm
|
| 5 |
-
from PIL import Image
|
| 6 |
import torch
|
| 7 |
-
import imagehash
|
| 8 |
-
ImageCache = None
|
| 9 |
|
| 10 |
-
class AnimalDataset(Dataset):
|
| 11 |
-
def __init__(self, df, transform=None):
|
| 12 |
-
self.paths = df["path"].values
|
| 13 |
-
self.targets = df["target"].values
|
| 14 |
-
self.encoded_target = df['encoded_target'].values
|
| 15 |
-
self.transform = transform
|
| 16 |
-
self.images = []
|
| 17 |
-
for path in tqdm(self.paths):
|
| 18 |
-
self.images.append(Image.open(path).convert("RGB").resize((224, 224)))
|
| 19 |
-
|
| 20 |
-
def __len__(self):
|
| 21 |
-
return len(self.paths)
|
| 22 |
-
|
| 23 |
-
def __getitem__(self, idx):
|
| 24 |
-
img = self.images[idx]
|
| 25 |
-
if self.transform:
|
| 26 |
-
img = self.transform(img)
|
| 27 |
-
target = self.targets[idx]
|
| 28 |
-
encoded_target = torch.tensor(self.encoded_target[idx]).type(torch.LongTensor)
|
| 29 |
-
return img, encoded_target, target
|
| 30 |
-
|
| 31 |
-
train_transform = transforms.Compose([
|
| 32 |
-
transforms.Resize((224,224)),
|
| 33 |
-
transforms.RandomHorizontalFlip(),
|
| 34 |
-
transforms.RandomRotation(10),
|
| 35 |
-
transforms.ToTensor(),
|
| 36 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 37 |
-
])
|
| 38 |
# Define the transformation pipeline
|
| 39 |
transform = transforms.Compose([
|
| 40 |
transforms.Resize((224,224)),
|
| 41 |
transforms.ToTensor(), # Convert the images to PyTorch tensors
|
| 42 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 43 |
-
])
|
| 44 |
-
|
| 45 |
-
class CustomImageCache:
|
| 46 |
-
def __init__(self, cache_size=50, debug=False):
|
| 47 |
-
self.cache = dict()
|
| 48 |
-
self.cache_size = 50
|
| 49 |
-
self.debug = debug
|
| 50 |
-
self.cache_hits = 0
|
| 51 |
-
self.cache_misses = 0
|
| 52 |
-
|
| 53 |
-
def __getitem__(self, image):
|
| 54 |
-
if isinstance(image, dict):
|
| 55 |
-
# Its the image and a mask as pillow both -> Combine them to one image
|
| 56 |
-
image = Image.blend(image["image"], image["mask"], alpha=0.5)
|
| 57 |
-
key = imagehash.average_hash(image)
|
| 58 |
-
|
| 59 |
-
if key in self.cache:
|
| 60 |
-
if self.debug: print("Cache hit!")
|
| 61 |
-
self.cache_hits += 1
|
| 62 |
-
return self.cache[key]
|
| 63 |
-
else:
|
| 64 |
-
if self.debug: print("Cache miss!")
|
| 65 |
-
self.cache_misses += 1
|
| 66 |
-
if len(self.cache.keys()) >= self.cache_size:
|
| 67 |
-
if self.debug: print("Cache full, popping item!")
|
| 68 |
-
self.cache.popitem()
|
| 69 |
-
self.cache[key] = image
|
| 70 |
-
return self.cache[key]
|
| 71 |
-
|
| 72 |
-
def __len__(self):
|
| 73 |
-
return len(self.cache.keys())
|
| 74 |
-
|
| 75 |
-
def print_info(self):
|
| 76 |
-
print(f"Cache size: {len(self)}")
|
| 77 |
-
print(f"Cache hits: {self.cache_hits}")
|
| 78 |
-
print(f"Cache misses: {self.cache_misses}")
|
| 79 |
-
|
| 80 |
-
def imageCacheWrapper(fn):
|
| 81 |
-
def wrapper(image):
|
| 82 |
-
return fn(ImageCache[image])
|
| 83 |
-
return wrapper
|
|
|
|
| 1 |
import torchvision.transforms as transforms
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import torch
|
|
|
|
|
|
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
# Define the transformation pipeline
|
| 5 |
transform = transforms.Compose([
|
| 6 |
transforms.Resize((224,224)),
|
| 7 |
transforms.ToTensor(), # Convert the images to PyTorch tensors
|
| 8 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 9 |
+
])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|