# Brain_Tumor_Detection / saliency.py
# Gradient-based saliency-map generation and visualization utilities.
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
import cv2
def generate_saliency(model, image_path, device):
    """Compute a gradient-based saliency map for the model's predicted class.

    The image is resized to 224x224, normalized with ImageNet statistics,
    and passed through the model; the gradient of the top-class score with
    respect to the input pixels is reduced to a single-channel map.

    Args:
        model: Classification model taking a (1, 3, 224, 224) normalized tensor.
        image_path: Path to the input image file.
        device: Torch device the model lives on.

    Returns:
        RGB uint8 numpy array with the saliency map overlaid on the image
        (produced by ``display_saliency_map``).
    """
    # Load and preprocess the image (ImageNet mean/std normalization).
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = transform(image).unsqueeze(0).to(device)
    input_tensor.requires_grad = True  # Enable gradient computation on input

    # Eval mode: disable dropout / batch-norm updates so the gradients
    # reflect the model's inference-time behavior.
    model.eval()
    output = model(input_tensor)

    # Select the class with the highest score and backpropagate its score.
    target_class = torch.argmax(output, dim=1).item()
    model.zero_grad()
    target_score = output[0, target_class]
    target_score.backward()

    # Extract input gradients; use .grad directly (.data is deprecated).
    gradients = input_tensor.grad.detach().cpu().numpy()[0]  # Shape: (C, H, W)

    # Grayscale saliency: per-pixel max of |gradient| across color channels.
    saliency_map = np.max(np.abs(gradients), axis=0)

    # Normalize to [0, 1]. The epsilon guards against a flat gradient map
    # (max == min), which would otherwise divide by zero and produce NaNs.
    denom = saliency_map.max() - saliency_map.min()
    saliency_map = (saliency_map - saliency_map.min()) / (denom + 1e-12)
    return display_saliency_map(image_path, saliency_map)
def display_saliency_map(image_path, saliency_map):
    """Overlay a [0, 1] saliency map on the original image.

    Salient regions are painted with the OpenCV HOT colormap (reddish
    tones) while the rest of the image is shown as a blue-tinted blend
    of the original photo.

    Args:
        image_path: Path to the original image file.
        saliency_map: 2-D float array in [0, 1], shape (H, W).

    Returns:
        RGB uint8 numpy array of shape (H, W, 3).
    """
    # Resize the original image to the saliency map's spatial dimensions
    # (PIL expects a (width, height) tuple).
    height, width = saliency_map.shape[:2]
    base = Image.open(image_path).convert('RGB').resize((width, height))
    base = np.array(base)

    # Colour-code the saliency values with the HOT colormap; OpenCV returns
    # BGR, so convert back to RGB for display.
    heat = cv2.applyColorMap((saliency_map * 255).astype(np.uint8), cv2.COLORMAP_HOT)
    heat = cv2.cvtColor(heat, cv2.COLOR_BGR2RGB)

    # Solid blue canvas used to tint the non-salient background.
    blue = np.zeros_like(base)
    blue[:, :, 2] = 255

    # Per-pixel blending weight in [0, 1], broadcast to three channels.
    weight = cv2.normalize(saliency_map, None, 0, 1, cv2.NORM_MINMAX)
    weight = np.repeat(weight[..., np.newaxis], 3, axis=-1)

    # Background layer: 30% blue tint over 70% of the original image.
    background = cv2.addWeighted(blue, 0.3, base, 0.7, 0)

    # Composite: heatmap where the map is salient, tinted photo elsewhere.
    composited = background * (1 - weight) + heat * weight
    return composited.astype(np.uint8)