# Spotix-API / backend/app/ai/explainer.py
# Anish
# [Updated Features] Updated the explainer and attribution to use a new model, and added support for better heatmaps.
# cd801e4
import torch
import numpy as np
import cv2
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
import logging
logger = logging.getLogger(__name__)

# Hugging Face Hub checkpoint used for both preprocessing and classification.
model_id = "prithivMLmods/Deep-Fake-Detector-Model"

# Both objects are created once, when this module is imported, and are then
# reused by every generate_heatmap() call. nn.Module.eval() returns the module
# itself, so the model can be switched to inference mode in the same statement.
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id).eval()
def _tokens_to_grid(token_scores: np.ndarray) -> np.ndarray:
    """Arrange a 1-D array of per-token scores on a square patch grid.

    Supports encoders that emit only patch tokens (``N == side**2``) as well as
    encoders that prepend one class token (``N == side**2 + 1``); in the latter
    case the leading token is dropped.

    Raises:
        ValueError: if the token count fits neither layout.
    """
    num_tokens = int(token_scores.shape[0])
    for offset in (0, 1):
        num_patches = num_tokens - offset
        if num_patches <= 0:
            continue
        side = int(round(float(np.sqrt(num_patches))))
        if side * side == num_patches:
            return token_scores[offset:].reshape(side, side)
    raise ValueError(f"Cannot arrange {num_tokens} tokens into a square patch grid")


def _normalize_cam(cam: np.ndarray) -> np.ndarray:
    """Clamp negative relevance to zero and rescale the map into [0, 1].

    An all-zero map is returned unchanged (no division by zero).
    """
    cam = np.maximum(cam, 0)
    cam = cam - np.min(cam)
    peak = np.max(cam)
    if peak != 0:
        cam = cam / peak
    return cam


def generate_heatmap(image_path: str, save_path: str) -> str:
    """Render a Grad-CAM style relevance heatmap for an image and save it.

    The image is classified with the module-level ``model``. The gradient of
    the top-scoring logit with respect to the output of the last encoder
    layer's ``layer_norm1`` is averaged over tokens to give one weight per
    channel; the weighted channel sum gives one relevance score per token.
    The scores are laid out on the patch grid, normalised to [0, 1],
    upsampled to the image size, colourised with ``COLORMAP_INFERNO`` and
    blended over the original image (60% image, 40% heatmap).

    Args:
        image_path: Path of the input image; it is converted to RGB.
        save_path: Destination path for the overlay image (``cv2.imwrite``).

    Returns:
        ``save_path`` once the overlay has been written.

    Raises:
        RuntimeError: if the hooks did not capture activations/gradients.
        ValueError: if the token count cannot form a square patch grid.
        OSError: if OpenCV fails to write ``save_path``.
        Any exception raised while loading the image or running the model.
        Every exception is logged (with traceback) and re-raised.
    """
    try:
        # Context manager closes the underlying file once the RGB copy exists.
        with Image.open(image_path) as src:
            img = src.convert("RGB")
        original_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

        inputs = processor(images=img, return_tensors="pt")
        x = inputs["pixel_values"]

        target_layer = model.vision_model.encoder.layers[-1].layer_norm1
        activations = []
        gradients = []

        def forward_hook(module, args, output):
            activations.append(output.detach())

        def backward_hook(module, grad_input, grad_output):
            gradients.append(grad_output[0].detach())

        handle_forward = target_layer.register_forward_hook(forward_hook)
        handle_backward = target_layer.register_full_backward_hook(backward_hook)
        try:
            outputs = model(x)
            logits = outputs.logits
            target_class = logits.argmax(dim=-1).item()
            score = logits[0, target_class]
            model.zero_grad()
            score.backward()
        finally:
            # The model is shared module-level state: always detach the hooks,
            # even when the forward/backward pass fails, so they cannot pile up
            # across calls.
            handle_forward.remove()
            handle_backward.remove()
            # Only the hooked tensors are needed; release parameter gradients.
            model.zero_grad(set_to_none=True)

        if not activations or not gradients:
            raise RuntimeError("Grad-CAM hooks did not capture activations/gradients")

        # First batch item; rows are tokens, columns are hidden channels.
        act = activations[0][0]
        grad = gradients[0][0]
        # One weight per channel: the gradient averaged over all tokens.
        weights = grad.mean(dim=0)
        # Weighted sum over channels -> one relevance score per token.
        cam_tokens = (act @ weights).float().cpu().numpy()

        cam = _normalize_cam(_tokens_to_grid(cam_tokens))

        heatmap_resized = cv2.resize(
            cam, (img.width, img.height), interpolation=cv2.INTER_CUBIC
        )
        # Bicubic interpolation overshoots [0, 1]; clip so the uint8
        # conversion below cannot wrap around and produce speckle artifacts.
        heatmap_resized = np.clip(heatmap_resized, 0.0, 1.0)
        heatmap_uint8 = np.uint8(255 * heatmap_resized)
        heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_INFERNO)
        overlay = cv2.addWeighted(original_cv, 0.6, heatmap_colored, 0.4, 0)

        # cv2.imwrite signals failure via its return value, not an exception.
        if not cv2.imwrite(save_path, overlay):
            raise OSError(f"cv2.imwrite could not write heatmap to {save_path!r}")
        return save_path
    except Exception as e:
        logger.exception("Failed to generate heatmap: %s", e)
        raise