mold-detection-api / gradcam.py
AdarshRajDS
Add mold detection FastAPI backend v2
e2f2323
raw
history blame contribute delete
838 Bytes
import torch
import numpy as np
class GradCAM:
    """Compute Grad-CAM heat maps from one layer of a model.

    The model's forward must return a mapping whose ``"class"`` entry is a
    tensor indexed as ``[:, class_idx]``. A hook on ``target_layer`` records
    that layer's output (the activations) and the gradient of the selected
    class score with respect to it. ``generate`` combines the two into a
    class-activation map.

    The hooks stay attached to ``target_layer`` until ``remove()`` is called,
    or until the ``with`` block exits when the instance is used as a context
    manager. Call one of these when you are done with the instance. Otherwise
    every new ``GradCAM`` on a shared model adds another set of hooks.

    Args:
        model: The network to explain.
        target_layer: A sub-module of ``model``. Its output is used as the
            feature map. ``generate`` averages over dims 2 and 3 and sums over
            dim 1, so the output is assumed to be ``(N, C, H, W)``.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.gradients = None    # dL/d(layer output) from the latest backward
        self.activations = None  # detached layer output from the latest forward
        # Keep the handles so the hooks can be removed later.
        self._handles = [target_layer.register_forward_hook(self._fwd)]

    def _fwd(self, module, inputs, output):
        """Forward hook: store the activations and arrange to capture their gradient."""
        # This hook runs on every forward pass of the model, including plain
        # inference, so it must never raise.
        if not torch.is_tensor(output):
            return
        # The CAM needs no autograd history, and holding the live tensor
        # would keep its graph node alive between calls.
        self.activations = output.detach()
        # A tensor hook receives dL/d(output) directly. It replaces the
        # deprecated Module.register_backward_hook without taking on the
        # full-backward-hook ban on in-place changes to the layer output.
        if output.requires_grad:
            output.register_hook(self._save_grad)

    def _save_grad(self, grad):
        """Tensor hook: record the gradient flowing into the layer output."""
        self.gradients = grad

    def generate(self, img_tensor, class_idx):
        """Return the normalized Grad-CAM map for ``class_idx`` on one input.

        Args:
            img_tensor: A single input without a batch dimension; one is added
                with ``unsqueeze(0)``.
            class_idx: Index into ``model(...)["class"]`` along dim 1.

        Returns:
            A numpy array holding the map for the single batch element, at the
            target layer's spatial resolution. Negative values are clipped to
            zero and the map is scaled so its maximum is 1. The map is all
            zeros if nothing contributes positively.

        Raises:
            RuntimeError: If no gradient can be computed, or if the target
                layer recorded no activations or gradient. The second case
                happens when the layer is not on the path to
                ``out["class"]`` or after ``remove()``.
        """
        # Clear results from a previous call so stale values are never used.
        self.gradients = None
        self.activations = None
        # Inference code usually runs under torch.no_grad(). Grad-CAM needs a
        # graph to backpropagate through, so turn gradients back on here.
        with torch.enable_grad():
            out = self.model(img_tensor.unsqueeze(0))
            score = out["class"][:, class_idx].sum()
            self.model.zero_grad()
            score.backward()
        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "GradCAM: target layer did not record activations/gradients; "
                "is it part of the forward pass that produces out['class']?"
            )
        # Channel weights are the gradients averaged over H and W.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = torch.relu((weights * self.activations).sum(dim=1))
        # Scale to [0, 1]. The epsilon avoids 0/0 when the map is all zeros.
        cam = cam / (cam.max() + 1e-8)
        return cam.detach().cpu().numpy()[0]

    def remove(self):
        """Detach the hooks from the target layer.

        After this call, ``generate`` raises ``RuntimeError``.
        """
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False