Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -11,6 +11,8 @@ import torch.nn.functional as F
|
|
| 11 |
import matplotlib.pyplot as plt
|
| 12 |
from PIL import Image, ImageEnhance
|
| 13 |
import torchvision.transforms as transforms
|
|
|
|
|
|
|
| 14 |
|
| 15 |
ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where())
|
| 16 |
|
|
@@ -19,34 +21,123 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
| 19 |
|
| 20 |
# Number of classes
|
| 21 |
num_classes = 6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# Load the pre-trained ResNet model
|
| 24 |
-
model = models.
|
| 25 |
-
for param in model.parameters():
|
| 26 |
-
param.requires_grad = False # Freeze feature extractor
|
| 27 |
|
| 28 |
# Modify the classifier for 6 classes with an additional hidden layer
|
| 29 |
-
model.fc = nn.Sequential(
|
| 30 |
-
|
| 31 |
-
)
|
|
|
|
|
|
|
| 32 |
|
| 33 |
# Load trained weights
|
| 34 |
-
model.load_state_dict(torch.load('
|
| 35 |
model.eval()
|
| 36 |
|
| 37 |
# Class labels
|
| 38 |
-
class_labels = ['bird', 'cat', 'deer', 'dog', 'frog', 'horse']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
# Image transformation function
|
| 41 |
def transform_image(image):
|
| 42 |
"""Preprocess the input image."""
|
| 43 |
mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]
|
| 44 |
img_size=224
|
| 45 |
-
transform = transforms.Compose([
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
img_tensor = transform(image).unsqueeze(0).to(device)
|
| 52 |
return img_tensor
|
|
@@ -109,18 +200,37 @@ def predict(image, brightness, contrast, hue, overlay_image, alpha):
|
|
| 109 |
with torch.no_grad():
|
| 110 |
output = model(image_tensor)
|
| 111 |
probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
# Generate Bar Chart
|
| 114 |
with plt.xkcd():
|
| 115 |
-
fig, ax = plt.subplots(figsize=(
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
| 117 |
ax.set_ylabel("Probability")
|
| 118 |
ax.set_title("Class Probabilities")
|
| 119 |
ax.set_ylim([0, 1])
|
| 120 |
-
|
|
|
|
|
|
|
| 121 |
ax.text(i, v + 0.02, f"{v:.2f}", ha='center', fontsize=10)
|
| 122 |
|
| 123 |
-
return final_image, fig
|
| 124 |
|
| 125 |
# Gradio Interface
|
| 126 |
with gr.Blocks() as interface:
|
|
@@ -137,10 +247,11 @@ with gr.Blocks() as interface:
|
|
| 137 |
|
| 138 |
with gr.Column():
|
| 139 |
processed_image = gr.Image(label="Final Processed Image")
|
|
|
|
| 140 |
bar_chart = gr.Plot(label="Class Probabilities")
|
| 141 |
|
| 142 |
inputs = [image_input, brightness, contrast, hue, overlay_input, alpha]
|
| 143 |
-
outputs = [processed_image, bar_chart]
|
| 144 |
|
| 145 |
# Event listeners for real-time updates
|
| 146 |
image_input.change(predict, inputs=inputs, outputs=outputs)
|
|
|
|
| 11 |
import matplotlib.pyplot as plt
|
| 12 |
from PIL import Image, ImageEnhance
|
| 13 |
import torchvision.transforms as transforms
|
| 14 |
+
import urllib
|
| 15 |
+
import json
|
| 16 |
|
| 17 |
ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where())
|
| 18 |
|
|
|
|
| 21 |
|
| 22 |
# Number of classes
|
| 23 |
num_classes = 6
|
| 24 |
+
# ResNet / ImageNet: fetch the class index so the 1000 outputs of the stock
# ResNet head can be mapped to human-readable label names.
# `import urllib` alone does not guarantee the `request` submodule is loaded,
# so import it explicitly before use.
import urllib.request

url = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
# The timeout keeps start-up from hanging indefinitely if the host is unreachable.
with urllib.request.urlopen(url, timeout=30) as f:
    imagenet_classes = json.load(f)

# Convert to dictionary format {0: "tench", 1: "goldfish", ..., 999: "toilet tissue"}
# NOTE: the variable name is historical; the values are ImageNet labels.
cifar10_classes = {int(k): v[1] for k, v in imagenet_classes.items()}
|
| 33 |
|
| 34 |
# Load the pre-trained ResNet-152 with its stock classification head.
# The model is moved to `device` because transform_image() places its output
# tensor there; leaving the model on the CPU would fail on a CUDA host.
model = models.resnet152(pretrained=True).to(device)

# Previous 6-class head with an additional hidden layer (disabled; the stock
# head is used instead):
# model.fc = nn.Sequential(
#     nn.Linear(model.fc.in_features, 512),
#     nn.ReLU(),
#     nn.Linear(512, num_classes)
# )

# Previous fine-tuned weights (disabled together with the custom head):
# model.load_state_dict(torch.load('model_old.pth', map_location=torch.device('cpu')))
model.eval()

# Class labels, ordered by class index so that class_labels[i] names output i.
# class_labels = ['bird', 'cat', 'deer', 'dog', 'frog', 'horse']
class_labels = [cifar10_classes[i] for i in range(len(cifar10_classes))]
|
| 51 |
+
|
| 52 |
+
class MultiLayerGradCAM:
    """Grad-CAM that averages class-activation maps over several layers.

    A forward hook and a backward hook on every target layer record the
    layer's output and the gradient of the class score with respect to that
    output. ``generate`` turns each (activation, gradient) pair into a
    heatmap, upsamples it to the input resolution and averages the results.
    """

    def __init__(self, model, target_layers=None):
        """Attach the recording hooks to *model*.

        Args:
            model: Module whose ``named_modules()`` contains the target layers.
            target_layers: Layer names as reported by ``named_modules()``.
                A single name may be given as a string. Defaults to
                ``['layer4']``.

        Raises:
            ValueError: If a requested layer name does not exist in the model.
        """
        self.model = model
        self.handles = []  # set first so __del__ is safe if validation raises
        if isinstance(target_layers, str):
            # A bare string would otherwise be substring-matched below.
            target_layers = [target_layers]
        self.target_layers = target_layers if target_layers else ['layer4']
        self.activations = []
        self.gradients = []

        # Fail fast on a misspelled name instead of silently producing an
        # empty (NaN) heatmap later.
        for layer_name in self.target_layers:
            self._find_layer(layer_name)

        # Register hooks
        for name, module in self.model.named_modules():
            if name in self.target_layers:
                self.handles.append(
                    module.register_forward_hook(self._forward_hook)
                )
                # register_backward_hook is deprecated; the "full" variant is
                # its documented replacement.
                self.handles.append(
                    module.register_full_backward_hook(self._backward_hook)
                )

    def _forward_hook(self, module, input, output):
        """Record the hooked layer's output."""
        self.activations.append(output.detach())

    def _backward_hook(self, module, grad_input, grad_output):
        """Record the gradient flowing into the hooked layer's output."""
        self.gradients.append(grad_output[0].detach())

    def _find_layer(self, layer_name):
        """Return the sub-module called *layer_name*.

        Raises:
            ValueError: If the model has no module with that name.
        """
        for name, module in self.model.named_modules():
            if name == layer_name:
                return module
        raise ValueError(f"Layer {layer_name} not found in model")

    def generate(self, input_tensor, target_class=None):
        """Compute the combined Grad-CAM heatmap for one input.

        Args:
            input_tensor: Batch holding a single image; its dimensions from
                index 2 onward are used as the heatmap size.
            target_class: Class index to explain. When ``None`` the arg-max
                of the model output is used.

        Returns:
            Tuple ``(heatmap, class_index)`` where ``heatmap`` is a NumPy
            array scaled to ``[0, 1]`` with the input's spatial size.

        Raises:
            RuntimeError: If the hooks did not capture one gradient per
                activation during this pass.
        """
        device = next(self.model.parameters()).device

        # The hooks also fire on every other forward pass through the model
        # (e.g. a plain inference call), so drop stale recordings; otherwise
        # old activations would be paired with the new gradients.
        self.activations.clear()
        self.gradients.clear()

        # Gradients are required even when the caller is inside no_grad().
        with torch.enable_grad():
            self.model.zero_grad()

            # Forward pass
            output = self.model(input_tensor.to(device))
            pred_class = torch.argmax(output).item() if target_class is None else target_class

            # Backward pass
            one_hot = torch.zeros_like(output)
            one_hot[0][pred_class] = 1
            output.backward(gradient=one_hot)

        if not self.activations or len(self.activations) != len(self.gradients):
            raise RuntimeError(
                "Grad-CAM hooks captured "
                f"{len(self.activations)} activations and "
                f"{len(self.gradients)} gradients"
            )

        # Process activations and gradients. Backward hooks fire in the
        # opposite order to forward hooks, hence reversed().
        heatmaps = []
        for act, grad in zip(self.activations, reversed(self.gradients)):
            # Compute weights
            weights = F.adaptive_avg_pool2d(grad, 1)

            # Create weighted combination of activation maps
            cam = torch.mul(act, weights).sum(dim=1, keepdim=True)
            cam = F.relu(cam)

            # Upsample to input size
            cam = F.interpolate(cam, size=input_tensor.shape[2:],
                                mode='bilinear', align_corners=False)
            # Index instead of squeeze() so a size-1 spatial axis survives.
            heatmaps.append(cam[0, 0].cpu().numpy())

        # Combine heatmaps from different layers
        combined_heatmap = np.mean(heatmaps, axis=0)

        # Normalize; the epsilon avoids division by zero for a flat map.
        combined_heatmap = np.maximum(combined_heatmap, 0)
        combined_heatmap = (combined_heatmap - combined_heatmap.min()) / \
            (combined_heatmap.max() - combined_heatmap.min() + 1e-10)

        return combined_heatmap, pred_class

    def __del__(self):
        """Detach all hooks when the instance is garbage-collected."""
        for handle in getattr(self, "handles", ()):
            handle.remove()
|
| 124 |
+
|
| 125 |
+
gradcam = MultiLayerGradCAM(model, target_layers=['layer3', 'layer4'])
|
| 126 |
|
| 127 |
# Image transformation function
|
| 128 |
def transform_image(image):
    """Preprocess the input image into a normalized single-image batch.

    Args:
        image: Image accepted by the torchvision transforms below
            (presumably a PIL image with three channels; confirm against
            the caller).

    Returns:
        Tensor with a leading batch axis of size 1 and a 224x224 spatial
        size, placed on the module-level ``device``.
    """
    transform = transforms.Compose([  # IMAGENET
        transforms.Resize(256),       # Resize shorter side to 256, keeping aspect ratio
        transforms.CenterCrop(224),   # Crop the center 224x224 region
        transforms.ToTensor(),        # Convert to tensor (scales to [0,1])
        transforms.Normalize(         # Normalize using ImageNet mean & std
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    # unsqueeze(0) adds the batch axis expected by the model.
    return transform(image).unsqueeze(0).to(device)
|
|
|
|
| 200 |
with torch.no_grad():
|
| 201 |
output = model(image_tensor)
|
| 202 |
probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
|
| 203 |
+
# pred_class = np.argmax(probabilities)
|
| 204 |
+
# top_5 = torch.topk(probabilities, 5)
|
| 205 |
+
|
| 206 |
+
heatmap, _ = gradcam.generate(image_tensor)
|
| 207 |
+
|
| 208 |
+
# Create GradCAM overlay
|
| 209 |
+
final_np = np.array(final_image)
|
| 210 |
+
heatmap = cv2.resize(heatmap, (final_np.shape[1], final_np.shape[0]))
|
| 211 |
+
heatmap = np.uint8(255 * heatmap)
|
| 212 |
+
heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
|
| 213 |
+
heatmap_rgb = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
|
| 214 |
+
superimposed = cv2.addWeighted(heatmap_rgb, 0.5, final_np, 0.5, 0)
|
| 215 |
+
gradcam_image = Image.fromarray(superimposed)
|
| 216 |
+
|
| 217 |
|
| 218 |
# Generate Bar Chart
|
| 219 |
with plt.xkcd():
|
| 220 |
+
fig, ax = plt.subplots(figsize=(6, 4))
|
| 221 |
+
top5_indices = np.argsort(probabilities)[-5:][::-1] # Indices of top 5 probabilities
|
| 222 |
+
top5_probs = probabilities[top5_indices]
|
| 223 |
+
top5_labels = [class_labels[i] for i in top5_indices]
|
| 224 |
+
ax.bar(top5_labels, top5_probs, color='skyblue')
|
| 225 |
ax.set_ylabel("Probability")
|
| 226 |
ax.set_title("Class Probabilities")
|
| 227 |
ax.set_ylim([0, 1])
|
| 228 |
+
plt.tight_layout(pad=3)
|
| 229 |
+
ax.set_xticklabels(top5_labels, rotation=45, ha="right", fontsize=8)
|
| 230 |
+
for i, v in enumerate(top5_probs):
|
| 231 |
ax.text(i, v + 0.02, f"{v:.2f}", ha='center', fontsize=10)
|
| 232 |
|
| 233 |
+
return final_image, gradcam_image, fig
|
| 234 |
|
| 235 |
# Gradio Interface
|
| 236 |
with gr.Blocks() as interface:
|
|
|
|
| 247 |
|
| 248 |
with gr.Column():
|
| 249 |
processed_image = gr.Image(label="Final Processed Image")
|
| 250 |
+
gradcam_output = gr.Image(label="GradCAM Overlay")
|
| 251 |
bar_chart = gr.Plot(label="Class Probabilities")
|
| 252 |
|
| 253 |
inputs = [image_input, brightness, contrast, hue, overlay_input, alpha]
|
| 254 |
+
outputs = [processed_image, gradcam_output, bar_chart]
|
| 255 |
|
| 256 |
# Event listeners for real-time updates
|
| 257 |
image_input.change(predict, inputs=inputs, outputs=outputs)
|