# Grad-CAM utilities for a VGG16 classifier: heatmap generation and two
# overlay renderers (JET blend and tinted/faded highlight).
import cv2
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
from torch.autograd import Variable
def generate_gradcam_heatmap(model, image, device, target_class=None):
    """Compute a Grad-CAM heatmap for a single image.

    Args:
        model: classifier exposing ``model.features[29]`` (VGG16's last conv layer).
        image: input tensor of shape (C, H, W); a batch dimension is added here.
        device: unused here, kept for interface compatibility with callers.
        target_class: class index (int or tensor) to explain; defaults to the
            model's predicted class.

    Returns:
        A 224x224 grayscale PIL Image, or False when the (predicted or given)
        class is 2 — callers treat that class as "nothing to visualize".
    """
    # Target the last convolutional layer of VGG16.
    target_layer = model.features[29]

    feature_maps = []
    gradients = []

    def save_feature_maps(module, inputs, output):
        feature_maps.append(output)

    def save_grads(module, grad_input, grad_output):
        # grad_output is a tuple; element 0 is the gradient w.r.t. the layer output.
        gradients.append(grad_output[0])

    # Keep the hook handles so they can be removed afterwards. The original
    # code registered fresh hooks on every call and never removed them, so
    # hooks accumulated on the layer across calls (a leak that also made
    # feature_maps/gradients indices unreliable). register_full_backward_hook
    # replaces the deprecated register_backward_hook.
    fwd_handle = target_layer.register_forward_hook(save_feature_maps)
    bwd_handle = target_layer.register_full_backward_hook(save_grads)

    try:
        image = image.unsqueeze(0)  # add batch dimension
        output = model(image)

        if target_class is None:
            target_class = torch.argmax(output, dim=1)  # predicted class
        # Accept both tensor and plain-int target_class (the original crashed
        # on ints because it unconditionally called .item()).
        cls = target_class.item() if torch.is_tensor(target_class) else int(target_class)
        if cls == 2:
            return False

        # Backprop the target class score to populate `gradients`.
        model.zero_grad()
        output[0, cls].backward()
    finally:
        fwd_handle.remove()
        bwd_handle.remove()

    feature_map = feature_maps[0]
    gradient = gradients[0]

    # Grad-CAM weights: global-average-pool the gradients over H and W,
    # weight the feature maps, sum over channels, then ReLU.
    weights = torch.mean(gradient, dim=[2, 3], keepdim=True)
    heatmap = torch.sum(weights * feature_map, dim=1, keepdim=False)
    heatmap = F.relu(heatmap)

    # Normalize to [0, 1]; guard the division against an all-zero map.
    heatmap = heatmap - torch.min(heatmap)
    max_val = torch.max(heatmap)
    if max_val > 0:
        heatmap = heatmap / max_val

    # Convert to a 0-255 grayscale PIL image at the network input size.
    heatmap = heatmap.squeeze().cpu().detach().numpy()
    heatmap = np.uint8(255 * heatmap)
    heatmap = Image.fromarray(heatmap).resize((224, 224))
    return heatmap
def overlay_heatmap(image_path, model, device):
    """Blend a JET-colored Grad-CAM heatmap over the image at `image_path`.

    Args:
        image_path: path to an image readable by PIL.
        model: classifier passed through to generate_gradcam_heatmap.
        device: torch device the preprocessed tensor is moved to.

    Returns:
        A float32 RGB array in [0, 1] with the heatmap blended in, or the
        original PIL image when generate_gradcam_heatmap declines to produce
        a map (predicted class 2).
    """
    image = Image.open(image_path).convert('RGB')

    # Standard ImageNet preprocessing for VGG16.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image_tensor = transform(image).to(device)

    heatmap = generate_gradcam_heatmap(model, image_tensor, device=device)
    # The original tested `heatmap != False`, which compares a PIL Image to a
    # bool; test identity against the sentinel instead.
    if heatmap is False:
        return image

    image = np.array(image)
    # Resize the heatmap to (width, height) of the original image.
    heatmap_resized = np.array(heatmap.resize(image.shape[1::-1]))
    max_val = np.max(heatmap_resized)
    if max_val > 0:  # guard the division against an all-zero heatmap
        heatmap_resized = np.uint8(255 * (heatmap_resized / max_val))

    # applyColorMap returns BGR; convert to RGB before blending onto the RGB
    # image (the original blended channel-swapped colors).
    overlay = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)
    overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
    overlay = np.float32(overlay) / 255

    # Additive blend with a 0.4 heatmap intensity, clipped back into [0, 1].
    result = np.float32(image) / 255 + 0.4 * overlay
    return np.clip(result, 0, 1)
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
def overlay_heatmap1(image_path, model, device):
    """Render a tinted Grad-CAM visualization for the image at `image_path`.

    Strongly-activated regions are highlighted in fading white over a tinted
    version of the original image.

    Args:
        image_path: path to an image readable by PIL.
        model: classifier passed through to generate_gradcam_heatmap.
        device: torch device the preprocessed tensor is moved to.

    Returns:
        A uint8 image array on success, or the original PIL image when
        generate_gradcam_heatmap declines to produce a map (predicted class 2).
    """
    image = Image.open(image_path).convert('RGB')

    # Standard ImageNet preprocessing for VGG16.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image_tensor = transform(image).to(device)

    heatmap = generate_gradcam_heatmap(model, image_tensor, device=device)
    # The original tested `heatmap != False`, which compares a PIL Image to a
    # bool; test identity against the sentinel instead.
    if heatmap is False:
        return image

    image = np.array(image)
    # Resize the heatmap to (width, height) of the original image.
    heatmap_resized = np.array(heatmap.resize(image.shape[1::-1]))
    max_val = np.max(heatmap_resized)
    if max_val > 0:  # guard the division against an all-zero heatmap
        heatmap_resized = np.uint8(255 * (heatmap_resized / max_val))

    # Binary mask of the strongly-activated regions, softened with a Gaussian
    # blur so the highlight fades out at the edges.
    _, binary_mask = cv2.threshold(heatmap_resized, 127, 255, cv2.THRESH_BINARY)
    blurred_mask = cv2.GaussianBlur(binary_mask, (21, 21), 0)
    mask_max = np.max(blurred_mask)
    if mask_max > 0:  # guard the division against an empty mask
        blurred_mask = blurred_mask / mask_max

    # Tint color [173, 0, 0] is red in RGB order, but the final BGR2RGB swap
    # below flips the channels, so it renders as a blue tint. (The original
    # comment called it "light blue", which was misleading.)
    tint = np.array([173, 0, 0], dtype=np.uint8)
    tint_background = np.full_like(image, tint)

    # White overlay whose per-pixel intensity follows the blurred mask.
    fading_overlay = np.zeros_like(image, dtype=np.uint8)
    for channel in range(3):
        fading_overlay[:, :, channel] = (blurred_mask * 255).astype(np.uint8)

    # Blend tint + fading highlight, then mix with the original image so the
    # underlying structure stays visible.
    result_image = cv2.addWeighted(tint_background, 0.8, fading_overlay, 0.9, 0)
    result_image = cv2.addWeighted(result_image, 0.8, image, 0.8, 0)

    # The original had unreachable plt display code after this return; removed.
    return cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
# Example usage:
# overlay_heatmap(image_path, model, device)
# overlay_heatmap1(image_path, model, device)