Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision import transforms | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| import numpy as np | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| class split_white_and_gray(): | |
| def __init__(self,threshold=120) -> None: | |
| """ | |
| Initialize the class with a threshold value. | |
| Args: | |
| threshold (int, optional): The threshold value to be set. Defaults to 120. | |
| """ | |
| self.threshold = threshold | |
| def __call__(self,tensor): | |
| """ | |
| Apply thresholding to the input tensor and return the white matter, gray matter, and the original tensor. | |
| Parameters: | |
| tensor (torch.Tensor): The input tensor to be thresholded. | |
| Returns: | |
| torch.Tensor: The thresholded white matter. | |
| torch.Tensor: The thresholded gray matter. | |
| torch.Tensor: The original input tensor. | |
| """ | |
| tensor = (tensor*255).to(torch.int64) | |
| # Apply thresholding | |
| white_matter = torch.where(tensor >= self.threshold,tensor,0) | |
| white_matter = (white_matter/255).to(torch.float64) | |
| gray_matter = torch.where(tensor < self.threshold,tensor,0) | |
| gray_matter = (gray_matter/255).to(torch.float64) | |
| tensor = (tensor/255).to(torch.float64) | |
| return white_matter, gray_matter,tensor | |
| def showcam_withoutmask(original_image, grayscale_cam, image_title='Original Image'): | |
| """This function applies the CAM mask to the original image and returns the Matplotlib Figure object. | |
| :param original_image: The original image tensor in PyTorch format. | |
| :param grayscale_cam: The CAM mask tensor in PyTorch format. | |
| :return: Matplotlib Figure object. | |
| """ | |
| # Assuming you have two tensors: 'original_image' and 'cam_mask' | |
| # Make sure both tensors are on the CPU | |
| original_image = torch.squeeze(original_image).cpu() # torch.Size([3, 150, 150]) | |
| cam_mask = grayscale_cam.cpu() # torch.Size([1, 150, 150]) | |
| # Convert the tensors to NumPy arrays | |
| original_image_np = original_image.numpy() | |
| cam_mask_np = cam_mask.numpy() | |
| # Apply the mask to the original image | |
| masked_image = original_image_np * cam_mask_np | |
| # Normalize the masked_image | |
| masked_image_norm = (masked_image - np.min(masked_image)) / (np.max(masked_image) - np.min(masked_image)) | |
| # Create Matplotlib Figure | |
| fig, axes = plt.subplots(1, 3, figsize=(18, 5)) | |
| # Plot the original image | |
| axes[0].imshow(original_image_np.transpose(1, 2, 0)) # Assuming your original image is in (C, H, W) format | |
| axes[0].set_title(image_title) | |
| # Plot the CAM mask | |
| axes[1].imshow(cam_mask_np[0], cmap='jet') # Assuming your mask is grayscale | |
| axes[1].set_title('CAM Mask') | |
| # Plot the overlay (normalized) | |
| axes[2].imshow(masked_image_norm.transpose(1, 2, 0)) # Assuming your original image is in (C, H, W) format | |
| axes[2].set_title('Overlay (Normalized)') | |
| return fig | |
| def showcam_withmask(img_tensor: torch.Tensor, | |
| mask_tensor: torch.Tensor, | |
| use_rgb: bool = False, | |
| colormap: int = cv2.COLORMAP_JET, | |
| image_weight: float = 0.5, | |
| image_title: str = 'Original Image') -> plt.Figure: | |
| """ This function overlays the CAM mask on the image as a heatmap and returns the Figure object. | |
| By default, the heatmap is in BGR format. | |
| :param img_tensor: The base image tensor in PyTorch format. | |
| :param mask_tensor: The CAM mask tensor in PyTorch format. | |
| :param use_rgb: Whether to use an RGB or BGR heatmap; set to True if 'img_tensor' is in RGB format. | |
| :param colormap: The OpenCV colormap to be used. | |
| :param image_weight: The final result is image_weight * img + (1-image_weight) * mask. | |
| :return: Matplotlib Figure object. | |
| """ | |
| # Convert PyTorch tensors to NumPy arrays | |
| img = img_tensor.cpu().numpy().transpose(1, 2, 0) | |
| mask = mask_tensor.cpu().numpy() | |
| # Convert the mask to a single-channel image | |
| mask_single_channel = np.uint8(255 * mask[0]) | |
| heatmap = cv2.applyColorMap(mask_single_channel, colormap) | |
| if use_rgb: | |
| heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) | |
| heatmap = np.float32(heatmap) / 255 | |
| if np.max(img) > 1: | |
| raise Exception("The input image should be in the range [0, 1]") | |
| if image_weight < 0 or image_weight > 1: | |
| raise Exception(f"image_weight should be in the range [0, 1]. Got: {image_weight}") | |
| cam = (1 - image_weight) * heatmap + image_weight * img | |
| cam = cam / np.max(cam) | |
| # Create Matplotlib Figure | |
| fig, axes = plt.subplots(1, 3, figsize=(15, 5)) | |
| # Plot the original image | |
| axes[0].imshow(img) | |
| axes[0].set_title(image_title) | |
| # Plot the CAM mask | |
| axes[1].imshow(mask[0], cmap='jet') | |
| axes[1].set_title('CAM Mask') | |
| # Plot the overlay | |
| axes[2].imshow(cam) | |
| axes[2].set_title('Overlay') | |
| return fig | |
| def predict_and_gradcam(pil_image, model, target=100, plot_type='withmask'): | |
| transform = transforms.Compose([ | |
| transforms.Resize((150, 150)), | |
| transforms.Grayscale(num_output_channels=3), | |
| transforms.ToTensor(), | |
| split_white_and_gray(120), | |
| ]) | |
| white_matter_tensor, gray_matter_tensor, origin_tensor = transform(pil_image) | |
| white_matter_tensor, gray_matter_tensor, origin_tensor = white_matter_tensor.unsqueeze(0).to(torch.float32),\ | |
| gray_matter_tensor.unsqueeze(0).to(torch.float32),\ | |
| origin_tensor.unsqueeze(0).to(torch.float32) | |
| def calculate_gradcammask(model_grad, input_tensor): | |
| target_layer = [model_grad.layer4[-1]] | |
| gradcam = GradCAM(model=model_grad, target_layers=target_layer) | |
| targets = [ClassifierOutputTarget(target)] | |
| grayscale_cam = gradcam(input_tensor=input_tensor, targets=targets, aug_smooth=True, eigen_smooth=True) | |
| grayscale_cam = torch.tensor(grayscale_cam) | |
| return grayscale_cam | |
| origin_model = model.resnet18_model | |
| white_model = model.whitematter_resnet18_model | |
| gray_model = model.graymatter_resnet18_model | |
| origin_cam = calculate_gradcammask(origin_model, origin_tensor) | |
| white_cam = calculate_gradcammask(white_model, white_matter_tensor) | |
| gray_cam = calculate_gradcammask(gray_model, gray_matter_tensor) | |
| class_idx = {0: 'Moderate Demented', 1: 'Mild Demented', 2: 'Very Mild Demented', 3: 'Non Demented'} | |
| prediction = model(white_matter_tensor, gray_matter_tensor, origin_tensor) | |
| predicted_class_index = torch.argmax(prediction).item() | |
| predicted_class_label = class_idx[predicted_class_index] | |
| if plot_type == 'withmask': | |
| return predicted_class_label, showcam_withmask(torch.squeeze(origin_tensor), origin_cam),\ | |
| showcam_withmask(torch.squeeze(white_matter_tensor), white_cam, image_title='White Matter'),\ | |
| showcam_withmask(torch.squeeze(gray_matter_tensor), gray_cam, image_title='Gray Matter') | |
| elif plot_type == 'withoutmask': | |
| return predicted_class_label, showcam_withoutmask(torch.squeeze(origin_tensor),origin_cam),\ | |
| showcam_withoutmask(torch.squeeze(white_matter_tensor),white_cam, image_title='White Matter'),\ | |
| showcam_withoutmask(torch.squeeze(gray_matter_tensor),gray_cam , image_title='Gray Matter') | |
| else: | |
| raise ValueError("plot_type must be either 'withmask' or 'withoutmask'") | |