# teeth / app.py
# OverMind0's picture
# Update app.py
# bf5ebd3 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import gradio as gr
import cv2
import matplotlib.pyplot as plt
class TeethCNN(nn.Module):
    """Four-stage convolutional classifier for teeth images.

    Each stage is a 3x3 convolution (padding 1) -> ReLU -> 2x2 max-pool,
    widening the channels 3 -> 32 -> 64 -> 128 -> 256. The head is
    fc1 (-> 256) -> ReLU -> dropout(0.3) -> fc2 (-> num_classes). fc1 expects
    a 256 x 14 x 14 feature map, which is what a 224 x 224 input yields after
    the four 2x poolings.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=7):
        super().__init__()
        # These attribute names become the state_dict keys of saved weights and
        # are also referenced directly (conv2/conv3/conv4) for Grad-CAM.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.3)
        self.fc1 = nn.Linear(256 * 14 * 14, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes) for ``x``."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(F.relu(stage(x)))
        features = torch.flatten(x, start_dim=1)
        hidden = self.dropout(F.relu(self.fc1(features)))
        return self.fc2(hidden)
class GradCAM:
    """Compute Grad-CAM heatmaps for one layer of a model.

    Hooks on ``target_layer`` capture its forward output (the activations) and
    the gradient flowing back into that output. The hooks are attached only
    while :meth:`generate` runs and are removed afterwards, so creating many
    ``GradCAM`` objects for the same model (e.g. one per request) does not
    accumulate hooks on the layer.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None  # grad of the explained score w.r.t. target_layer's output
        self.activations = None  # target_layer's forward output from the last generate()

    def _register_hooks(self):
        """Attach capture hooks to ``target_layer`` and return their handles."""
        def forward_hook(module, input, output):
            self.activations = output

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0]

        return [
            self.target_layer.register_forward_hook(forward_hook),
            self.target_layer.register_full_backward_hook(backward_hook),
        ]

    def generate(self, input_tensor, class_idx=None):
        """Return a Grad-CAM heatmap for the first sample of ``input_tensor``.

        Puts the model in eval mode, runs a forward pass, backpropagates the
        chosen class score, and weights each channel of the target layer's
        output by its spatially averaged gradient.

        Args:
            input_tensor: Batch passed straight to ``self.model``; only sample 0
                is explained.
            class_idx: Index of the output logit to explain. Defaults to the
                highest-scoring class for sample 0.

        Returns:
            A NumPy array with the spatial shape of the target layer's output,
            scaled to ``[0, 1]``. It is all zeros when the map is constant
            (e.g. no location contributes positively to the class).

        Raises:
            RuntimeError: If the target layer produced no activations or
                received no gradient during the pass.
        """
        self.model.eval()
        # Clear state from any previous call so stale tensors are never used.
        self.activations = None
        self.gradients = None
        handles = self._register_hooks()
        try:
            # Gradients are required even if the caller is inside torch.no_grad().
            with torch.enable_grad():
                output = self.model(input_tensor)
                if class_idx is None:
                    class_idx = output[0].argmax().item()
                score = output[0, class_idx]
                self.model.zero_grad()
                score.backward()
        finally:
            # Always detach the hooks so they never outlive this call.
            for handle in handles:
                handle.remove()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "target_layer produced no activations or gradients for this input"
            )

        gradients = self.gradients[0]  # per-sample, channels first
        activations = self.activations[0]
        # One weight per channel: the spatial mean of that channel's gradient.
        weights = gradients.mean(dim=(1, 2))
        cam = torch.relu((weights[:, None, None] * activations).sum(dim=0))
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:  # a constant map would otherwise become 0/0 = NaN
            cam = cam / peak
        return cam.detach().cpu().numpy()
# Load the trained classifier once at import time and build the matching
# preprocessing pipeline.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Index i of a model output maps to class_names[i]; presumably this is the
# label order used in training — verify against the training script.
class_names = ['CaS', 'CoS', 'Gum', 'MC', 'OC', 'OLP', 'OT']
model = TeethCNN(num_classes=len(class_names))
# NOTE: torch.load unpickles the file; only point it at trusted weight files.
model.load_state_dict(
    torch.load("teeth_model_weights.pth", map_location=device)
)
model = model.to(device)
model.eval()

# Resize to 224x224, convert to a [0, 1] tensor, then shift each RGB channel
# to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
# One GradCAM per explained layer, created lazily and reused across requests.
_gradcam_by_layer = {}


def predict_with_gradcam(image):
    """Classify ``image`` and build Grad-CAM overlays for conv2, conv3 and conv4.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        ``(label, overlay_conv2, overlay_conv3, overlay_conv4)``. Each overlay
        is an (H, W, 3) float array in ``[0, 1]`` at the model's input
        resolution: the image blended 50/50 with a JET-colored heatmap.

    Raises:
        gr.Error: If no image was provided (Gradio shows the message to the user).
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    image = image.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)
    # Size the overlays to whatever resolution the preprocessing produced.
    height, width = input_tensor.shape[-2:]

    # The plain prediction needs no autograd graph.
    with torch.no_grad():
        pred_idx = model(input_tensor).argmax(dim=1).item()
    pred_label = class_names[pred_idx]

    # Base image at the model's input size, scaled to [0, 1].
    img_np = np.array(image.resize((width, height))) / 255.0

    visualizations = []
    for layer in (model.conv2, model.conv3, model.conv4):
        # Reuse one GradCAM per layer instead of constructing one per request,
        # so hooks a GradCAM registers on the layer are not duplicated on
        # every call of this function.
        gradcam = _gradcam_by_layer.get(layer)
        if gradcam is None:
            gradcam = _gradcam_by_layer[layer] = GradCAM(model, layer)
        # Explain the class that is actually displayed to the user.
        cam = gradcam.generate(input_tensor, class_idx=pred_idx)
        # Normalizing a constant map can produce 0/0 = NaN; show it as "no signal".
        cam = np.nan_to_num(cam, nan=0.0)
        cam_resized = cv2.resize(cam, (width, height))
        heatmap = cv2.applyColorMap(
            np.uint8(np.clip(255 * cam_resized, 0, 255)), cv2.COLORMAP_JET
        )
        # applyColorMap returns BGR; the base image is RGB.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        overlay = np.clip(0.5 * img_np + 0.5 * heatmap / 255.0, 0, 1)
        visualizations.append(overlay)
    return pred_label, *visualizations
def _build_interface():
    """Assemble the Gradio UI around ``predict_with_gradcam``.

    One PIL image goes in. The predicted label comes out, followed by the
    three Grad-CAM overlays, in the same order ``predict_with_gradcam``
    returns them.
    """
    image_input = gr.Image(type="pil")
    result_outputs = [
        gr.Label(label="Predicted Class"),
        gr.Image(label="Grad-CAM: Conv2"),
        gr.Image(label="Grad-CAM: Conv3"),
        gr.Image(label="Grad-CAM: Conv4"),
    ]
    return gr.Interface(
        fn=predict_with_gradcam,
        inputs=image_input,
        outputs=result_outputs,
        title="🦷 Teeth Disease Classifier with Grad-CAM",
        description="Upload a teeth image. The model predicts the class and shows Grad-CAM visualizations for multiple convolutional layers.",
    )


interface = _build_interface()
interface.launch()