|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
import cv2 |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
|
class TeethCNN(nn.Module):
    """Four-stage convolutional classifier for 3-channel images.

    Each stage is a 3x3 convolution (padding 1, so height/width are kept),
    a ReLU, and a 2x2 max-pool that halves height and width. ``fc1`` expects
    a 256-channel 14x14 map after the four pools, i.e. a 224x224 input.

    Args:
        num_classes: Number of logits produced by the final layer.
    """

    def __init__(self, num_classes=7):
        super().__init__()
        # Attribute names double as state_dict keys for the saved weights,
        # so they (and their creation order) are kept exactly as trained.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.3)
        self.fc1 = nn.Linear(256 * 14 * 14, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return raw class scores of shape ``(batch, num_classes)``."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(F.relu(stage(x)))
        flat = x.view(x.shape[0], -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)
|
|
|
|
|
|
|
|
class GradCAM:
    """Grad-CAM heatmaps for a single layer of a classification model.

    Each channel of ``target_layer``'s output is weighted by the spatially
    averaged gradient of the chosen class score for that channel. The weighted
    channels are summed, passed through ReLU and rescaled to ``[0, 1]``.

    Hooks are attached to ``target_layer`` only while :meth:`generate` runs
    and are removed afterwards. Constructing many ``GradCAM`` objects for the
    same long-lived model therefore does not accumulate hooks on it.

    Assumes ``target_layer`` outputs a ``(batch, channels, height, width)``
    tensor (e.g. an ``nn.Conv2d``). Only the first batch item is explained.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        # Filled in by the hooks during ``generate``: the layer's output and
        # the gradient of the class score with respect to that output.
        self.gradients = None
        self.activations = None
        # RemovableHandles for hooks currently attached to target_layer.
        self._hook_handles = []

    def _register_hooks(self):
        """Attach hooks that capture the layer's activations and gradients.

        Does nothing if this instance's hooks are already attached.
        """
        if self._hook_handles:
            return

        def forward_hook(module, inputs, output):
            self.activations = output

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0]

        self._hook_handles = [
            self.target_layer.register_forward_hook(forward_hook),
            self.target_layer.register_full_backward_hook(backward_hook),
        ]

    def remove_hooks(self):
        """Detach every hook this instance attached to ``target_layer``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def generate(self, input_tensor, class_idx=None):
        """Compute the Grad-CAM heatmap for the first item of ``input_tensor``.

        Puts ``model`` into eval mode and runs one forward and one backward
        pass. This overwrites parameter ``.grad`` values via ``zero_grad``.

        Args:
            input_tensor: Batch fed to ``model``. Only index 0 is visualised.
            class_idx: Output column to explain. Defaults to the arg-max
                prediction, which requires a batch of one because of ``.item()``.

        Returns:
            NumPy array with the spatial size of ``target_layer``'s output and
            values in ``[0, 1]``. It is all zeros when the map has no positive
            evidence, instead of NaN.

        Raises:
            RuntimeError: if the hooks on ``target_layer`` did not fire during
                the forward/backward pass.
        """
        self.model.eval()
        self.gradients = None
        self.activations = None
        self._register_hooks()
        try:
            # Gradients are required even if the caller is inside no_grad().
            with torch.enable_grad():
                output = self.model(input_tensor)
                if class_idx is None:
                    class_idx = output.argmax(dim=1).item()
                # .sum() yields the scalar backward() needs; for a batch of
                # one it equals output[0, class_idx].
                score = output[:, class_idx].sum()
                self.model.zero_grad()
                score.backward()
        finally:
            # Never leave hooks behind on the (possibly shared) model.
            self.remove_hooks()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "GradCAM: target_layer did not capture activations/gradients; "
                "is it used in the model's forward pass?"
            )

        activations = self.activations[0].detach()
        gradients = self.gradients[0].detach()

        # One weight per channel: its gradient averaged over height and width.
        weights = gradients.mean(dim=(1, 2))
        cam = torch.relu((weights[:, None, None] * activations).sum(dim=0))

        cam = cam - cam.min()
        peak = cam.max()
        # Skip normalisation for a flat map to avoid 0/0 -> NaN.
        if peak > 0:
            cam = cam / peak
        return cam.cpu().numpy()
|
|
|
|
|
|
|
|
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Label for each output index of the classifier (looked up by arg-max index).
# NOTE(review): order presumably matches the training label order — confirm.
class_names = ['CaS', 'CoS', 'Gum', 'MC', 'OC', 'OLP', 'OT']

model = TeethCNN(num_classes=len(class_names))
# The checkpoint is a plain state_dict, so restrict unpickling to tensors and
# basic containers. A tampered .pth file then cannot run arbitrary code on load.
# (weights_only requires torch >= 1.13 and is the default from torch 2.6.)
model.load_state_dict(
    torch.load("teeth_model_weights.pth", map_location=device, weights_only=True)
)
model.to(device)
# Inference only: disables dropout.
model.eval()
|
|
|
|
|
|
|
|
# Preprocessing for every uploaded image before it reaches the model.
# 1. Resize to 224x224, the size TeethCNN.fc1 is built for.
# 2. Convert to a float tensor in [0, 1].
# 3. Shift/scale each RGB channel with mean 0.5 and std 0.5, i.e. into [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
|
|
|
|
|
# One GradCAM per explained layer, reused across requests. Building fresh
# GradCAM objects on every call can keep attaching hooks to the shared global
# model, so each layer gets a single instance for the life of the process.
_gradcam_by_layer = {}


def _gradcam_for(layer):
    """Return the cached GradCAM for *layer* of the global model, creating it once."""
    gradcam = _gradcam_by_layer.get(layer)
    if gradcam is None:
        gradcam = GradCAM(model, layer)
        _gradcam_by_layer[layer] = gradcam
    return gradcam


def _overlay_heatmap(img_np, cam):
    """Blend a [0, 1] heatmap onto an (H, W, 3) float RGB image in [0, 1].

    Returns an (H, W, 3) uint8 RGB array.
    """
    # A degenerate map may come back as NaN; treat it as "no evidence".
    cam = np.nan_to_num(cam, nan=0.0)
    # cv2 takes dsize as (width, height).
    cam_resized = cv2.resize(cam, (img_np.shape[1], img_np.shape[0]))
    heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    # applyColorMap produces BGR; the rest of the pipeline is RGB.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    overlay = np.clip(0.5 * img_np + 0.5 * heatmap / 255.0, 0, 1)
    return (overlay * 255).round().astype(np.uint8)


def predict_with_gradcam(image):
    """Classify an uploaded image and build Grad-CAM overlays for conv2-conv4.

    Args:
        image: PIL image from the Gradio input, or None when nothing was
            uploaded.

    Returns:
        ``(label, conv2_overlay, conv3_overlay, conv4_overlay)``. ``label`` is
        an entry of ``class_names``. Each overlay is a 224x224x3 uint8 RGB array.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    image = image.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Plain inference needs no autograd graph.
    with torch.no_grad():
        pred_idx = model(input_tensor).argmax(dim=1).item()
    pred_label = class_names[pred_idx]

    img_np = np.array(image.resize((224, 224))) / 255.0

    visualizations = []
    for layer in (model.conv2, model.conv3, model.conv4):
        # Explain the class that is actually reported.
        cam = _gradcam_for(layer).generate(input_tensor, class_idx=pred_idx)
        visualizations.append(_overlay_heatmap(img_np, cam))

    return pred_label, *visualizations
|
|
|
|
|
|
|
|
# Gradio UI: one image in; out come the predicted label plus one Grad-CAM
# overlay per explained layer. The order of ``outputs`` must match the order
# of the values returned by predict_with_gradcam.
interface = gr.Interface(
    title="🦷 Teeth Disease Classifier with Grad-CAM",
    description=(
        "Upload a teeth image. The model predicts the class and shows "
        "Grad-CAM visualizations for multiple convolutional layers."
    ),
    fn=predict_with_gradcam,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(label="Predicted Class"),
        gr.Image(label="Grad-CAM: Conv2"),
        gr.Image(label="Grad-CAM: Conv3"),
        gr.Image(label="Grad-CAM: Conv4"),
    ],
)
|
|
|
|
|
|
|
|
# Start the web app only when run as a script, so that importing this module
# (e.g. to reuse the model or predict_with_gradcam) does not launch a server.
if __name__ == "__main__":
    interface.launch()
|
|
|