Spaces:
Running
Running
Anish
[Updated Features] > Updated the explainer and attribution to use a new model, and added support for a better heatmap.
cd801e4 | import torch | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| import logging | |
logger = logging.getLogger(__name__)

# Hugging Face Hub checkpoint shared by the image processor and the classifier.
model_id = "prithivMLmods/Deep-Fake-Detector-Model"

# Both objects are loaded once, at import time, and kept as module-level state.
processor = AutoImageProcessor.from_pretrained(model_id)
# nn.Module.eval() switches the module to inference mode and returns the same
# module, so binding the chained call is equivalent to calling eval() afterwards.
model = AutoModelForImageClassification.from_pretrained(model_id).eval()
def _to_square_grid(cam):
    """Arrange per-token scores on a square grid, rectify them, and scale to [0, 1]."""
    import math

    num_tokens = cam.shape[0]
    side = math.isqrt(num_tokens)
    if side == 0:
        raise ValueError("No tokens available to build a heatmap")
    # The previous implementation assumed exactly 14 x 14 = 196 tokens. Any
    # tokens beyond the largest square are assumed to be leading non-patch
    # tokens (e.g. a CLS token) and are dropped — TODO confirm for other
    # checkpoints. With 196 tokens nothing is dropped.
    cam = cam[num_tokens - side * side:].reshape(side, side)

    cam = np.maximum(cam, 0)
    cam = cam - np.min(cam)
    peak = np.max(cam)
    if peak != 0:
        cam = cam / peak
    return cam


def _token_relevance_grid(img):
    """Run Grad-CAM for the top-scoring class of ``img``; return a [0, 1] square grid."""
    inputs = processor(images=img, return_tensors="pt")
    x = inputs["pixel_values"]
    target_layer = model.vision_model.encoder.layers[-1].layer_norm1

    activations = []
    gradients = []

    def forward_hook(module, args, output):
        activations.append(output.detach())

    def backward_hook(module, grad_input, grad_output):
        gradients.append(grad_output[0].detach())

    # NOTE(review): the hooks live on the shared module-level model, so two
    # overlapping calls could capture each other's tensors — confirm that calls
    # are serialized by the caller.
    handle_forward = target_layer.register_forward_hook(forward_hook)
    handle_backward = target_layer.register_full_backward_hook(backward_hook)
    try:
        # Grad-CAM needs autograd even if a caller disabled it around this call.
        with torch.enable_grad():
            logits = model(x).logits
            target_class = logits.argmax(dim=-1).item()
            model.zero_grad()
            logits[0, target_class].backward()
    finally:
        # Always detach the hooks. A hook left behind by a failed call would
        # fire on every later forward pass of the shared model.
        handle_forward.remove()
        handle_backward.remove()
        # Free the parameter gradients produced by backward(); only the hooked
        # gradient at target_layer is used.
        model.zero_grad(set_to_none=True)

    if not activations or not gradients:
        raise RuntimeError("Grad-CAM hooks did not capture activations/gradients")

    act = activations[0][0]     # first captured pass, first batch item: (tokens, channels)
    grad = gradients[0][0]
    weights = grad.mean(dim=0)  # one weight per channel, averaged over tokens
    # Same as summing weights[i] * act[:, i] over every channel i.
    cam = (act @ weights).float().cpu().numpy()
    return _to_square_grid(cam)


def generate_heatmap(image_path: str, save_path: str) -> str:
    """Render a Grad-CAM style heatmap for the model's top prediction and save it.

    The image goes through the module-level ``processor`` and ``model``.
    Activations and gradients are captured at
    ``model.vision_model.encoder.layers[-1].layer_norm1``. Each token's
    activation is weighted by the token-averaged gradient of the top logit,
    giving one relevance score per token. The scores are placed on a square
    grid, rectified, and normalized to [0, 1]. The grid is then resized to the
    image size, coloured with ``cv2.COLORMAP_INFERNO``, and blended over the
    original image (0.6 image, 0.4 heatmap).

    Args:
        image_path: Path of the input image; it is converted to RGB.
        save_path: Destination of the overlay; ``cv2.imwrite`` picks the
            output format from its extension.

    Returns:
        ``save_path``, after the overlay has been written successfully.

    Raises:
        Exception: Any failure is logged with its traceback and re-raised.
            This includes ``OSError`` when ``cv2.imwrite`` reports failure and
            ``RuntimeError`` when the hooks captured nothing.
    """
    try:
        # The context manager closes the file handle. convert() returns a fully
        # loaded image that stays valid after the with-block.
        with Image.open(image_path) as src:
            img = src.convert("RGB")
        original_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

        cam = _token_relevance_grid(img)

        heatmap_resized = cv2.resize(cam, (img.width, img.height), interpolation=cv2.INTER_CUBIC)
        # INTER_CUBIC can overshoot outside [0, 1]. Clip before the uint8 cast so
        # overshoot does not wrap around to the opposite end of the colormap.
        heatmap_uint8 = np.uint8(255 * np.clip(heatmap_resized, 0.0, 1.0))
        heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_INFERNO)
        overlay = cv2.addWeighted(original_cv, 0.6, heatmap_colored, 0.4, 0)

        # cv2.imwrite signals failure by returning False instead of raising.
        if not cv2.imwrite(save_path, overlay):
            raise OSError(f"cv2.imwrite could not write heatmap to {save_path!r}")
        return save_path
    except Exception as e:
        logger.exception("Failed to generate heatmap: %s", e)
        raise