# 6928481W / app.py
# Last commit 977b4cd by jgoh064: "re-upload of app.py"
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import cv2
import gradio as gr
import os
# ---------- Runtime configuration ----------
# Pick the compute device once, at import time: the default CUDA device when
# PyTorch reports one is available, the CPU otherwise.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fine-tuned classifier weights. This is a relative path, so it is resolved
# against the process's working directory; the file has to be shipped next to
# the app (e.g. uploaded together with the Gradio/Space files).
model_path = "best_model.pth"
# ---------- Build fine-tuned EfficientNet ----------
def build_efficientnet_finetune(*, pretrained=True, dropout=0.5, num_trainable_stages=2):
    """Build an EfficientNet-B0 with a single-logit binary head, moved to ``DEVICE``.

    The convolutional backbone (``model.features``) is frozen except for its
    last ``num_trainable_stages`` entries, and torchvision's stock classifier is
    replaced by ``Dropout -> Linear(in_features, 1)``, so the model emits one raw
    logit per image (apply ``torch.sigmoid`` to get a probability).

    Args:
        pretrained: If True (the default, matching the original behaviour),
            initialise from torchvision's ImageNet-1k V1 weights, which may
            require a download. Pass False when a complete fine-tuned state
            dict is loaded right afterwards, since the ImageNet weights would be
            overwritten anyway.
        dropout: Dropout probability used in the new classifier head.
        num_trainable_stages: Number of trailing entries of ``model.features``
            left with ``requires_grad=True``; 0 freezes the whole backbone.

    Returns:
        The model on ``DEVICE``, in ``nn.Module``'s default (training) mode;
        call ``.eval()`` before inference.

    Raises:
        ValueError: if ``num_trainable_stages`` is negative.
    """
    if num_trainable_stages < 0:
        raise ValueError("num_trainable_stages must be >= 0")

    weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.efficientnet_b0(weights=weights)

    # Freeze the entire backbone first ...
    for param in model.features.parameters():
        param.requires_grad = False
    # ... then re-enable the last few entries of `features`. In torchvision's
    # B0 the final entry is the 1x1 conv head rather than an MBConv stage, so
    # the default of 2 covers the last MBConv stage plus that head.
    # Guard 0 explicitly: `features[-0:]` would select *every* entry.
    if num_trainable_stages:
        for param in model.features[-num_trainable_stages:].parameters():
            param.requires_grad = True

    # Swap the 1000-way ImageNet classifier for a single-logit binary head.
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(dropout),
        nn.Linear(in_features, 1),
    )
    return model.to(DEVICE)
# ---------- Load model ----------
# Rebuild the architecture, then restore the fine-tuned weights. The checkpoint
# is passed straight to load_state_dict, so it is a plain state dict of tensors:
# weights_only=True is sufficient and stops torch.load's pickle machinery from
# executing arbitrary objects embedded in the file. map_location lets a
# checkpoint saved on a GPU load on a CPU-only host.
model = build_efficientnet_finetune()
model.load_state_dict(torch.load(model_path, map_location=DEVICE, weights_only=True))
# Inference only: disable dropout and use BatchNorm running statistics.
model.eval()
# ---------- Inference preprocessing ----------
# Eval-time pipeline with no augmentation:
#   1. squash the image to exactly 224x224 (aspect ratio is not preserved),
#   2. convert the PIL image to a float CHW tensor,
#   3. normalise each channel with the ImageNet mean/std the backbone expects.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# ---------- Grad-CAM ----------
def generate_gradcam(model, input_tensor, target_layer=None):
    """Compute a Grad-CAM heatmap for a single-logit binary classifier.

    Only the first image of the batch (index 0) is explained. The heatmap
    targets the class the model predicts for that image: the raw logit when
    sigmoid(logit) > 0.5 (class 1), otherwise the negated logit (class 0).

    Args:
        model: Module returning logits of shape (N, 1). When ``target_layer``
            is None it must expose a ``features`` sequence, whose last entry
            becomes the target layer.
        input_tensor: Preprocessed image batch of shape (N, C, H, W). A copy is
            moved to the model's device; the caller's tensor is not modified.
        target_layer: Module whose output feature map is explained.

    Returns:
        (cam, prob): ``cam`` is a NumPy array of shape (H, W) min-max scaled
        to [0, 1]; ``prob`` is sigmoid of the first logit as a Python float.
    """
    model.eval()
    if target_layer is None:
        target_layer = model.features[-1]

    captured = []

    def forward_hook(module, inputs, output):
        captured.append(output)

    device = next(model.parameters()).device
    # A detached copy that requires grad guarantees the target layer's output
    # is part of the autograd graph even if every parameter upstream of it is
    # frozen, without flipping requires_grad on the caller's tensor.
    x = input_tensor.detach().to(device).requires_grad_(True)

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        # enable_grad keeps this working when called inside torch.no_grad().
        with torch.enable_grad():
            logit = model(x)[0, 0]
            prob = torch.sigmoid(logit)
            # Class 0 uses -logit: same gradient as `1 - logit`.
            score = logit if prob.item() > 0.5 else -logit
            fmap = captured[0]
            # d(score)/d(fmap) directly: no backward hook is needed and the
            # parameters' .grad buffers are never populated.
            (grad,) = torch.autograd.grad(score, fmap)
    finally:
        handle.remove()  # always detach the hook, even if the pass fails

    # Restrict to the explained image so batches of any size work.
    fmap = fmap[:1].detach()
    grad = grad[:1]
    weights = grad.mean(dim=(2, 3), keepdim=True)  # (1, C, 1, 1) pooled gradients
    cam = torch.relu((weights * fmap).sum(dim=1, keepdim=True))  # (1, 1, h, w)
    cam = F.interpolate(cam, size=tuple(x.shape[-2:]), mode="bilinear", align_corners=False)
    cam = cam[0, 0].cpu().numpy()
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    return cam, prob.item()
# ---------- Prediction Function ----------
def predict(image):
    """Classify an uploaded image as Fake/Real and build a Grad-CAM overlay.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None when
            the user submits without uploading anything.

    Returns:
        (scores, overlay): ``scores`` maps "Fake"/"Real" to probabilities, with
        P(Real) = sigmoid of the model's logit. ``overlay`` is a PIL image that
        blends the resized input (weight 0.6) with a JET-coloured Grad-CAM
        heatmap (weight 0.4).

    Raises:
        gr.Error: if no image was provided; Gradio shows the message to the user.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    image_pil = image.convert("RGB")
    image_tensor = transform(image_pil).unsqueeze(0)

    # generate_gradcam already runs the forward pass on the model's device and
    # returns sigmoid(logit), so a separate prediction pass is unnecessary.
    # A bare `model(image_tensor)` would also feed a CPU tensor to a model
    # that may live on CUDA.
    cam, prob = generate_gradcam(model, image_tensor)

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV colour maps are BGR
    # Match the heatmap's size; PIL's resize takes (width, height).
    original = np.array(image_pil.resize((cam.shape[1], cam.shape[0])))
    overlay = cv2.addWeighted(original, 0.6, heatmap, 0.4, 0)

    classes = ["Fake", "Real"]  # 0 -> Fake, 1 -> Real; the logit scores "Real"
    return {classes[0]: 1 - prob, classes[1]: prob}, Image.fromarray(overlay)
# ---------- Gradio Interface ----------
# One PIL image in; two outputs back: the Fake/Real probabilities (as a label
# widget showing both classes) and the Grad-CAM overlay image.
interface = gr.Interface(
    fn=predict,
    title="Fake vs Real Face Detector with Grad-CAM",
    description=(
        "Upload an image to detect if it is Fake or Real. "
        "Grad-CAM heatmap highlights model attention."
    ),
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=2),
        gr.Image(type="pil"),
    ],
)
# Start the web server as soon as the module runs.
interface.launch()