| import torch |
| import numpy as np |
|
|
| |
| |
| import cv2 |
| import re |
|
|
| from .image_utils import show_cam_on_image, show_overlapped_cam |
|
|
|
|
def rn_relevance(
    image,
    target_features,
    img_encoder,
    method,
    device,
    neg_saliency=False,
    img_dim=224,
):
    """Compute a CAM relevance map for ``image`` with respect to ``target_features``.

    The CAM is taken on the last block of ``img_encoder.layer4``, upsampled
    bilinearly to ``img_dim`` and min-max normalised.

    Args:
        image: Input batch; only ``image[0]`` is converted for display.
            Assumed to be a ``(1, C, H, W)`` tensor -- TODO confirm with callers.
        target_features: Encoding forwarded to the CAM object as
            ``target_encoding``.
        img_encoder: Image encoder exposing an indexable ``layer4`` attribute.
        method: CAM factory called as
            ``method(model=..., target_layers=..., use_cuda=...)``; the object
            it returns is called with ``input_tensor`` and ``target_encoding``.
        device: Device name or ``torch.device``. CUDA is requested only when
            it is available and the device is not a CPU device.
        neg_saliency: If true, the negated ``target_features`` is used as the
            CAM target.
        img_dim: Output size handed to ``torch.nn.functional.interpolate``.

    Returns:
        A ``(image_relevance, image)`` tuple of numpy arrays: the relevance map
        scaled to ``[0, 1)`` and ``image[0]`` moved to channel-last layout and
        min-max normalised.
    """
    target_layers = [img_encoder.layer4[-1]]

    # Compare on the device *type* so that torch.device("cpu") is recognised
    # as CPU too; a torch.device object never compares equal to the str "cpu".
    on_cpu = str(device).split(":")[0] == "cpu"
    cam = method(
        model=img_encoder,
        target_layers=target_layers,
        use_cuda=torch.cuda.is_available() and not on_cpu,
    )

    target_encoding = -target_features if neg_saliency else target_features

    raw_map = cam(input_tensor=image, target_encoding=target_encoding)[0].squeeze()
    # as_tensor accepts both numpy arrays and tensors.
    relevance = torch.as_tensor(raw_map, dtype=torch.float32)

    # Use the map's own height/width instead of assuming it is square.
    height, width = relevance.shape[-2:]
    relevance = relevance.reshape(1, 1, height, width)
    relevance = torch.nn.functional.interpolate(
        relevance, size=img_dim, mode="bilinear"
    )
    relevance = relevance[0, 0].detach().cpu().numpy()
    relevance = (relevance - relevance.min()) / (
        1e-7 + relevance.max() - relevance.min()
    )

    image = image[0].permute(1, 2, 0).data.cpu().numpy()
    lowest = image.min()
    span = image.max() - lowest
    if span > 0:
        image = (image - lowest) / span
    else:
        # A constant image would otherwise produce 0/0 = NaN.
        image = np.zeros_like(image)

    return relevance, image
|
|
|
|
def interpret_rn(
    image,
    target_features,
    img_encoder,
    method,
    device,
    neg_saliency=False,
    img_dim=224,
):
    """Render the CAM relevance of ``image`` as an overlay picture.

    The relevance map and display image come from ``rn_relevance`` and are
    blended with ``show_cam_on_image``. The blend is scaled by 255, cast to
    ``uint8`` and passed through ``cv2.COLOR_RGB2BGR``.

    Args:
        image, target_features, img_encoder, method, device, img_dim:
            Forwarded unchanged to ``rn_relevance``.
        neg_saliency: Forwarded to both ``rn_relevance`` and
            ``show_cam_on_image``.

    Returns:
        The ``uint8`` array produced by ``cv2.cvtColor``.
    """
    relevance_map, display_image = rn_relevance(
        image,
        target_features,
        img_encoder,
        method,
        device,
        neg_saliency=neg_saliency,
        img_dim=img_dim,
    )
    overlay = show_cam_on_image(
        display_image, relevance_map, neg_saliency=neg_saliency
    )
    overlay_u8 = np.uint8(255 * overlay)
    return cv2.cvtColor(np.array(overlay_u8), cv2.COLOR_RGB2BGR)
| |
|
|
|
|
def interpret_rn_overlapped(
    image, target_features, img_encoder, method, device, img_dim=224
):
    """Render positive and negative CAM relevance in a single overlay.

    ``rn_relevance`` is run twice, first with ``neg_saliency=False`` and then
    with ``neg_saliency=True``. The two maps are combined with
    ``show_overlapped_cam``, scaled by 255, cast to ``uint8`` and passed
    through ``cv2.COLOR_RGB2BGR``.

    Args:
        image, target_features, img_encoder, method, device, img_dim:
            Forwarded unchanged to ``rn_relevance``.

    Returns:
        The ``uint8`` array produced by ``cv2.cvtColor``.
    """
    common_args = (image, target_features, img_encoder, method, device)

    # Same call order as before: the positive pass runs first.
    positive_map, _ = rn_relevance(*common_args, neg_saliency=False, img_dim=img_dim)
    negative_map, display_image = rn_relevance(
        *common_args, neg_saliency=True, img_dim=img_dim
    )

    overlay = show_overlapped_cam(display_image, negative_map, positive_map)
    overlay_u8 = np.uint8(255 * overlay)
    return cv2.cvtColor(np.array(overlay_u8), cv2.COLOR_RGB2BGR)
| |
|
|
|
|
def rn_perword_relevance(
    image,
    text,
    clip_model,
    clip_tokenizer,
    method,
    device,
    masked_word="",
    data_only=False,
    img_dim=224,
):
    """Compute a CAM relevance map for one word of ``text``.

    The CAM target is the difference between the L2-normalised text features
    of ``text`` and of ``text`` with every occurrence of ``masked_word``
    removed. The CAM is taken on the last block of
    ``clip_model.visual.layer4``, upsampled bilinearly to ``img_dim`` and
    min-max normalised.

    Args:
        image: Input passed to the CAM object as ``input_tensor``.
        text: Full caption.
        clip_model: Model exposing ``eval()``, ``encode_text`` and
            ``visual.layer4``.
        clip_tokenizer: Callable turning a string into an object with ``.to``.
        method: CAM factory called as
            ``method(model=..., target_layers=..., use_cuda=...)``.
        device: Device name or ``torch.device``. CUDA is requested only when
            it is available and the device is not a CPU device.
        masked_word: Literal substring removed from ``text``. It is matched
            verbatim, so regex metacharacters such as ``+`` or ``.`` are safe.
        data_only: Accepted for backward compatibility; the return value is
            the relevance map either way.
        img_dim: Output size handed to ``torch.nn.functional.interpolate``.

    Returns:
        The relevance map as a numpy array scaled to ``[0, 1)``.
    """
    clip_model.eval()

    main_text = clip_tokenizer(text).to(device)
    # Literal removal: the word must not be interpreted as a regex pattern.
    masked_text = clip_tokenizer(text.replace(masked_word, "")).to(device)

    main_text_features = clip_model.encode_text(main_text)
    masked_text_features = clip_model.encode_text(masked_text)

    main_unit = main_text_features / main_text_features.norm(dim=-1, keepdim=True)
    masked_unit = masked_text_features / masked_text_features.norm(
        dim=-1, keepdim=True
    )
    target_encoding = main_unit - masked_unit

    target_layers = [clip_model.visual.layer4[-1]]

    # Compare on the device *type* so that torch.device("cpu") is recognised
    # as CPU too; a torch.device object never compares equal to the str "cpu".
    on_cpu = str(device).split(":")[0] == "cpu"
    cam = method(
        model=clip_model.visual,
        target_layers=target_layers,
        use_cuda=torch.cuda.is_available() and not on_cpu,
    )

    raw_map = cam(input_tensor=image, target_encoding=target_encoding)[0].squeeze()
    # as_tensor accepts both numpy arrays and tensors.
    relevance = torch.as_tensor(raw_map, dtype=torch.float32)

    # Use the map's own height/width instead of assuming it is square.
    height, width = relevance.shape[-2:]
    relevance = relevance.reshape(1, 1, height, width)
    relevance = torch.nn.functional.interpolate(
        relevance, size=img_dim, mode="bilinear"
    )
    relevance = relevance[0, 0].detach().cpu().numpy()
    relevance = (relevance - relevance.min()) / (
        1e-7 + relevance.max() - relevance.min()
    )

    return relevance
|
|
|
|
def interpret_perword_rn(
    image,
    text,
    clip_model,
    clip_tokenizer,
    method,
    device,
    masked_word="",
    data_only=False,
    img_dim=224,
):
    """Render the per-word CAM relevance of ``image`` as an overlay picture.

    The relevance map comes from ``rn_perword_relevance``. Before blending,
    ``image[0]`` is moved to channel-last layout and min-max normalised, the
    same preparation ``interpret_rn`` gets from ``rn_relevance``. The blend
    from ``show_cam_on_image`` is scaled by 255, cast to ``uint8`` and passed
    through ``cv2.COLOR_RGB2BGR``.

    Args:
        image: Input batch; assumed to be a ``(1, C, H, W)`` tensor --
            TODO confirm with callers.
        text, clip_model, clip_tokenizer, method, device, masked_word, img_dim:
            Forwarded unchanged to ``rn_perword_relevance``.
        data_only: Forwarded to ``rn_perword_relevance``; it does not change
            what this function returns.

    Returns:
        The ``uint8`` array produced by ``cv2.cvtColor``.
    """
    image_relevance = rn_perword_relevance(
        image,
        text,
        clip_model,
        clip_tokenizer,
        method,
        device,
        masked_word,
        data_only=data_only,
        img_dim=img_dim,
    )

    # rn_perword_relevance returns only the map, so the display image has to
    # be prepared here rather than handing the raw batch tensor to the overlay.
    display_image = image[0].permute(1, 2, 0).data.cpu().numpy()
    lowest = display_image.min()
    span = display_image.max() - lowest
    if span > 0:
        display_image = (display_image - lowest) / span
    else:
        # A constant image would otherwise produce 0/0 = NaN.
        display_image = np.zeros_like(display_image)

    vis = show_cam_on_image(display_image, image_relevance)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)

    return vis
| |
|
|