File size: 4,376 Bytes
f224ba4 bf5ebd3 f224ba4 bf5ebd3 f224ba4 bf5ebd3 f224ba4 bf5ebd3 f224ba4 bf5ebd3 f224ba4 bf5ebd3 f224ba4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import gradio as gr
import cv2
import matplotlib.pyplot as plt
class TeethCNN(nn.Module):
    """Four-stage convolutional classifier.

    Every stage is a 3x3 convolution (padding 1) -> ReLU -> 2x2 max-pool, so
    the spatial size is halved four times. ``fc1`` expects 256 x 14 x 14
    features, i.e. an input whose height/width pool down to 14 (such as the
    224x224 images produced by this file's preprocessing).

    The attribute names below are the keys of the saved ``state_dict``; do not
    rename them or change their registration order.
    """

    def __init__(self, num_classes=7):
        super().__init__()
        # Feature extractor: channel width doubles at each stage.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Classifier head.
        self.dropout = nn.Dropout(0.3)
        self.fc1 = nn.Linear(256 * 14 * 14, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)``."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(F.relu(stage(x)))
        features = x.flatten(1)
        hidden = self.dropout(F.relu(self.fc1(features)))
        return self.fc2(hidden)
class GradCAM:
    """Compute Grad-CAM heatmaps for a single layer of a model.

    Hooks are attached to ``target_layer`` only for the duration of a
    :meth:`generate` call and are always removed afterwards (even if the
    forward/backward pass raises). Creating many ``GradCAM`` objects for the
    same model, e.g. one per request, therefore never accumulates hooks on it.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None    # grad of the score w.r.t. target_layer's output (last generate)
        self.activations = None  # target_layer's forward output (last generate)
        self._handles = []       # RemovableHandle objects for currently attached hooks

    def _register_hooks(self):
        """Attach forward/backward hooks that capture activations and gradients."""
        # Never stack a second pair of hooks on the layer.
        self.remove_hooks()

        def forward_hook(module, inputs, output):
            self.activations = output

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0]

        self._handles = [
            self.target_layer.register_forward_hook(forward_hook),
            self.target_layer.register_full_backward_hook(backward_hook),
        ]

    def remove_hooks(self):
        """Detach every hook this instance attached to ``target_layer``."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def generate(self, input_tensor, class_idx=None):
        """Return a Grad-CAM map for the first sample of ``input_tensor``.

        Args:
            input_tensor: model input batch; the map is computed for sample 0.
            class_idx: class whose score is explained; defaults to the
                model's top-scoring class for sample 0.

        Returns:
            A 2-D numpy array with the spatial size of ``target_layer``'s
            output, scaled to [0, 1]. A map with no positive evidence is
            returned as all zeros rather than NaN.

        Raises:
            RuntimeError: if ``target_layer`` was not exercised by the
                forward/backward pass, so nothing was captured.

        Side effects: puts the model in eval mode and leaves parameter
        gradients from this backward pass on the model.
        """
        self.model.eval()
        # Drop results from any previous call so stale tensors are never reused.
        self.gradients = None
        self.activations = None
        self._register_hooks()
        try:
            output = self.model(input_tensor)
            if class_idx is None:
                class_idx = output[0].argmax().item()
            self.model.zero_grad()
            # sum() makes the score a scalar for any batch size; each sample's
            # gradient is unaffected by the other samples' scores.
            output[:, class_idx].sum().backward()
        finally:
            self.remove_hooks()

        if self.gradients is None or self.activations is None:
            raise RuntimeError("target_layer was not reached during the forward/backward pass")

        gradients = self.gradients[0].detach()
        activations = self.activations[0].detach()
        # Channel importance = spatially averaged gradient (Grad-CAM weights).
        weights = gradients.mean(dim=(1, 2))
        cam = torch.relu((weights[:, None, None] * activations).sum(dim=0))
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:  # a flat map would otherwise become 0/0 = NaN
            cam = cam / peak
        return cam.cpu().numpy()
# Load the trained classifier once, at import time.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Position i in this list is the label reported for the model's i-th logit.
class_names = ['CaS', 'CoS', 'Gum', 'MC', 'OC', 'OLP', 'OT']
model = TeethCNN(num_classes=len(class_names))
# NOTE(review): torch.load unpickles the file -- only ship weights from a trusted source.
model.load_state_dict(torch.load("teeth_model_weights.pth", map_location=device))
model.to(device).eval()
# Preprocessing: resize to the 224x224 input the network was built for,
# convert to a tensor, then normalise every channel with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
def _overlay_heatmap(base_rgb, cam, size=(224, 224), alpha=0.5):
    """Blend a Grad-CAM map (JET colormap) onto an RGB image.

    Args:
        base_rgb: float RGB image in [0, 1] whose width/height equal ``size``.
        cam: 2-D heatmap, expected in [0, 1]; NaNs are treated as 0 and
            values are clipped so the uint8 conversion cannot wrap around.
        size: ``(width, height)`` passed to ``cv2.resize``.
        alpha: weight of the heatmap; the image gets ``1 - alpha``.

    Returns:
        Float RGB array clipped to [0, 1].
    """
    cam = np.nan_to_num(np.asarray(cam, dtype=np.float32), nan=0.0)
    cam_resized = np.clip(cv2.resize(cam, size), 0.0, 1.0)
    heat = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    # OpenCV colormaps are BGR; the displayed image is RGB.
    heat = cv2.cvtColor(heat, cv2.COLOR_BGR2RGB) / 255.0
    return np.clip((1.0 - alpha) * base_rgb + alpha * heat, 0, 1)


def predict_with_gradcam(image):
    """Classify an uploaded image and render Grad-CAM overlays.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            None when the user submits without uploading anything.

    Returns:
        ``(label, conv2_overlay, conv3_overlay, conv4_overlay)``; each overlay
        is a 224x224 float RGB array in [0, 1].

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    image = image.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Plain inference: no autograd graph is needed just to pick the label.
    with torch.no_grad():
        output = model(input_tensor)
    pred_idx = output.argmax(dim=1).item()
    pred_label = class_names[pred_idx]

    # Base image for the overlays, at the network's input resolution.
    img_np = np.array(image.resize((224, 224))) / 255.0

    visualizations = []
    for layer in (model.conv2, model.conv3, model.conv4):
        # Explain the class that is actually shown to the user.
        cam = GradCAM(model, layer).generate(input_tensor, class_idx=pred_idx)
        visualizations.append(_overlay_heatmap(img_np, cam))
    return pred_label, *visualizations
# Gradio UI: one image in; the predicted label plus three Grad-CAM overlays
# out, in the same order predict_with_gradcam returns them.
# NOTE(review): launch() runs at import time (no __main__ guard).
interface = gr.Interface(
    title="🦷 Teeth Disease Classifier with Grad-CAM",
    description="Upload a teeth image. The model predicts the class and shows Grad-CAM visualizations for multiple convolutional layers.",
    fn=predict_with_gradcam,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(label="Predicted Class"),
        gr.Image(label="Grad-CAM: Conv2"),
        gr.Image(label="Grad-CAM: Conv3"),
        gr.Image(label="Grad-CAM: Conv4"),
    ],
)
interface.launch()
|