# Source provenance (Hugging Face upload metadata): justhariharan — "Upload 22 files" — commit 26c2a4a (verified)
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
import os
# Import GradCAM tools
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from src.models.model import VisionGuardModel
class VisionGuardPredictor:
    """Inference wrapper around VisionGuardModel with Grad-CAM explainability.

    Loads a trained checkpoint, classifies a single image as FAKE/REAL,
    and renders a Grad-CAM heatmap overlay explaining the prediction.
    """

    def __init__(self, model_path, config_path="configs/config.yaml"):
        """Load the checkpoint and set up Grad-CAM and preprocessing.

        Args:
            model_path: Path to a saved ``state_dict`` checkpoint.
            config_path: Accepted for interface compatibility; not read here.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🚀 Loading Inference Engine on: {self.device}")

        # 1. Load model weights (checkpoint is expected to be a plain state_dict).
        self.model = VisionGuardModel(num_classes=2, pretrained=False)
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint)
        self.model.to(self.device)
        self.model.eval()

        # 2. Grad-CAM setup: target the first norm layer of the backbone's
        # last transformer block (presumably a DINOv2 ViT — see comments below).
        target_layers = [self.model.backbone.blocks[-1].norm1]

        def reshape_transform(tensor):
            # ViT-style backbones emit [batch, 1 + num_patches, channels]
            # with the CLS token first. Drop the CLS token and fold the
            # patch sequence back into a square spatial grid so Grad-CAM
            # can treat it like a CNN feature map.
            patches = tensor[:, 1:, :]
            # Infer the grid side from the patch count instead of
            # hard-coding 16 (e.g. 224px / patch 14 -> 256 patches -> 16x16),
            # so other input resolutions keep working.
            grid_size = int(patches.size(1) ** 0.5)
            result = patches.reshape(
                tensor.size(0), grid_size, grid_size, tensor.size(2)
            )
            # [B, H, W, C] -> [B, C, H, W]
            return result.transpose(2, 3).transpose(1, 2)

        self.cam = GradCAM(
            model=self.model,
            target_layers=target_layers,
            reshape_transform=reshape_transform,
        )

        # 3. Preprocessing: 224x224 resize + ImageNet normalization.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        self.labels = ['FAKE', 'REAL']

    def predict(self, image_path):
        """Classify one image and return verdict, confidence and heatmap.

        Args:
            image_path: Path to an image file readable by PIL.

        Returns:
            dict with keys:
              - "verdict": 'FAKE' or 'REAL'
              - "confidence": percentage rounded to 2 decimals
              - "probabilities": per-class softmax values (4 decimals)
              - "heatmap": PIL.Image with the Grad-CAM overlay
        """
        # 1. Load the image; the context manager closes the underlying file
        # handle (convert() forces the pixel data into memory first).
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        # Clean 0-1 float copy at model resolution for the overlay.
        vis_image = np.float32(image.resize((224, 224))) / 255.0

        # 2. Model-space preprocessing: [1, 3, 224, 224] on the right device.
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # 3. Inference — no gradients needed for the classification pass.
        with torch.no_grad():
            outputs = self.model(input_tensor)
            probs = F.softmax(outputs, dim=1)
            confidence, predicted_class = torch.max(probs, 1)

        # 4. Grad-CAM heatmap. targets=None tells GradCAM to explain the
        # highest-scoring class, which coincides with the predicted class.
        # This runs outside no_grad because Grad-CAM needs gradients.
        grayscale_cam = self.cam(input_tensor=input_tensor, targets=None)[0, :]
        visualization = show_cam_on_image(vis_image, grayscale_cam, use_rgb=True)
        heatmap_pil = Image.fromarray(visualization)

        # 5. Structured output for the UI layer (e.g. Gradio).
        idx = predicted_class.item()
        return {
            "verdict": self.labels[idx],
            "confidence": round(confidence.item() * 100, 2),  # .item() is already float
            "probabilities": {
                "FAKE": round(probs[0][0].item(), 4),
                "REAL": round(probs[0][1].item(), 4),
            },
            "heatmap": heatmap_pil,
        }