OverMind0 commited on
Commit
f224ba4
·
verified ·
1 Parent(s): 2d81e54

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import numpy as np
7
+ import gradio as gr
8
+ import cv2
9
+ import matplotlib.pyplot as plt
10
+
11
+ # Define your CNN model
12
class TeethCNN(nn.Module):
    """Four-stage convolutional classifier for 224x224 RGB teeth images.

    Each stage is conv -> ReLU -> 2x2 max-pool, doubling the channel count
    (3 -> 32 -> 64 -> 128 -> 256) while halving the spatial resolution.
    """

    def __init__(self, num_classes=7):
        super(TeethCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.3)
        # Four 2x2 poolings shrink 224 -> 14, hence the 256 * 14 * 14 flat size.
        self.fc1 = nn.Linear(256 * 14 * 14, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return raw class logits for a batch shaped (N, 3, 224, 224)."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(F.relu(conv(x)))
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)
33
+
34
+ # GradCAM logic
35
class GradCAM:
    """Grad-CAM heatmap generator for a single convolutional layer.

    Registers a forward hook (to capture activations) and a full backward
    hook (to capture gradients) on `target_layer`, then weights the
    activation maps by their spatially-averaged gradients.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Keep the hook handles so the hooks can be detached later;
        # the original discarded them, making removal impossible (hook leak).
        self._handles = []
        self._register_hooks()

    def _register_hooks(self):
        def forward_hook(module, inputs, output):
            self.activations = output

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0]

        self._handles.append(
            self.target_layer.register_forward_hook(forward_hook))
        self._handles.append(
            self.target_layer.register_full_backward_hook(backward_hook))

    def remove_hooks(self):
        """Detach this instance's hooks from the target layer."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def generate(self, input_tensor, class_idx=None):
        """Return a normalized (H, W) CAM as a numpy array in [0, 1].

        If `class_idx` is None, the model's top-scoring class is used.
        """
        self.model.eval()
        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output.argmax(dim=1).item()
        self.model.zero_grad()
        output[:, class_idx].backward()
        gradients = self.gradients[0]       # (C, H, W)
        activations = self.activations[0]   # (C, H, W)
        # Channel weights: global-average-pooled gradients.
        weights = gradients.mean(dim=(1, 2))
        # Vectorized weighted sum over channels (was a Python loop).
        cam = (weights[:, None, None] * activations).sum(dim=0)
        cam = torch.relu(cam)
        cam = cam - cam.min()
        max_val = cam.max()
        # Guard: if ReLU zeroed the whole map, dividing by max() would
        # produce NaNs; return the all-zero map instead.
        if max_val > 0:
            cam = cam / max_val
        return cam.detach().cpu().numpy()
71
+
72
+ # Load model
73
# Load model: pick GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class_names = ['CaS', 'CoS', 'Gum', 'MC', 'OC', 'OLP', 'OT']
model = TeethCNN(num_classes=len(class_names))
# weights_only=True restricts unpickling to tensor/primitive data, closing
# the arbitrary-code-execution hole a tampered checkpoint would open.
state_dict = torch.load(
    "teeth_model_weights.pth", map_location=device, weights_only=True
)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # inference-only: disable dropout
79
+
80
+ # Preprocessing
81
# Preprocessing: resize to the 224x224 size the network expects, convert to
# a [0, 1] tensor, then scale each channel to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
87
+
88
def predict_with_gradcam(image):
    """Classify a teeth image and return (label, Grad-CAM overlay).

    Parameters:
        image: PIL.Image uploaded via Gradio.
    Returns:
        pred_label: predicted class name (str).
        overlay: float RGB array in [0, 1], 224x224, blending the input
            with a JET-colored Grad-CAM heatmap 50/50.
    """
    image = image.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Plain classification pass needs no autograd graph.
    with torch.no_grad():
        output = model(input_tensor)
    pred_idx = output.argmax(dim=1).item()
    pred_label = class_names[pred_idx]

    # Reuse a single GradCAM instance. The original constructed a new one
    # per request, registering two fresh hooks on model.conv4 each time and
    # never removing them — an unbounded hook leak that slows every
    # subsequent forward pass.
    gradcam = getattr(predict_with_gradcam, "_gradcam", None)
    if gradcam is None:
        gradcam = GradCAM(model, model.conv4)
        predict_with_gradcam._gradcam = gradcam
    # Pass the predicted class explicitly so the heatmap always matches
    # the label we report.
    cam = gradcam.generate(input_tensor, class_idx=pred_idx)
    cam_resized = cv2.resize(cam, (224, 224))

    img_np = np.array(image.resize((224, 224))) / 255.0
    cam_overlay = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    cam_overlay = cv2.cvtColor(cam_overlay, cv2.COLOR_BGR2RGB)
    overlay = (0.5 * img_np + 0.5 * cam_overlay / 255.0)
    overlay = np.clip(overlay, 0, 1)

    return pred_label, overlay
106
+
107
+ # Gradio interface
108
# Gradio interface: single image input, class label + Grad-CAM overlay out.
interface = gr.Interface(
    fn=predict_with_gradcam,
    inputs=gr.Image(type="pil"),
    outputs=["label", "image"],
    title="🦷 Teeth Disease Classifier with Grad-CAM",
    description="Upload an image of teeth and the model will predict the disease with Grad-CAM visualization.",
)

interface.launch()