# AI_DeepFake_Detection_System / image_backend.py
# Uploaded by bhoumik12 ("Upload 5 files", commit 960b635, verified).
import torch
from torchvision import transforms
from transformers import ViTForImageClassification, ViTConfig
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import io
import os
# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fine-tuned checkpoint location. Look next to this module first so importing
# it does not depend on the process's working directory; if no file exists
# there, fall back to the same path relative to the working directory.
_MODEL_REL_PATH = os.path.join("model", "vit_real_fake_best.pth")
_MODULE_DIR = (
    os.path.dirname(os.path.abspath(__file__))
    if "__file__" in globals()
    else os.getcwd()
)
_MODEL_PATH = os.path.join(_MODULE_DIR, _MODEL_REL_PATH)
if not os.path.isfile(_MODEL_PATH):
    _MODEL_PATH = _MODEL_REL_PATH

# Base ViT configuration with a 2-way classification head. Attention maps are
# requested so the heatmap code can read them from the model outputs.
config = ViTConfig.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=2,
    output_attentions=True,
)
# ignore_mismatched_sizes lets the pretrained classifier head (whose size
# differs from num_labels=2) be replaced by a freshly initialized one; every
# weight is then overwritten by the strict load_state_dict call below.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    config=config,
    ignore_mismatched_sizes=True,
)
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict needs (and is torch's default from 2.6 onward).
model.load_state_dict(
    torch.load(_MODEL_PATH, map_location=device, weights_only=True)
)
model.to(device)
# Inference mode: disables dropout.
model.eval()
# Preprocessing applied to every input image before it reaches the ViT:
# resize to a fixed 224x224, convert the PIL image to a float tensor, then
# normalize each RGB channel. The mean/std values are the widely used
# ImageNet statistics — presumably what the checkpoint was trained with;
# confirm against the training pipeline before changing them.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def _cls_attention_heatmap(attentions):
    """Turn ViT attention outputs into a min-max normalized CLS heatmap.

    Args:
        attentions: per-layer attention tuple from a forward pass run with
            ``output_attentions=True``. Only the last entry is used, indexed
            as (batch, heads, tokens, tokens); assumes token 0 is the CLS
            token and the remaining tokens form a square patch grid.

    Returns:
        2-D float numpy array of shape (grid, grid) for the first image in
        the batch, scaled to [0, 1]. If every patch gets the same attention,
        an all-zero map is returned instead of NaNs.

    Raises:
        RuntimeError: if the model returned no attention maps.
    """
    if attentions is None or len(attentions) == 0:
        raise RuntimeError(
            "Model did not return attention maps; call it with "
            "output_attentions=True using an attention implementation that "
            "supports it."
        )
    # Average the last layer over heads and keep the first batch item.
    attn = attentions[-1].mean(dim=1)[0]
    # Row 0 = attention paid by the CLS token; drop the CLS->CLS entry so
    # only patch tokens remain.
    cls_attn = attn[0, 1:]
    grid = int(cls_attn.size(0) ** 0.5)
    heat = cls_attn.reshape(grid, grid).float().cpu().numpy()
    lo, hi = heat.min(), heat.max()
    span = hi - lo
    if span <= 0:
        # Uniform attention: min-max scaling would divide by zero.
        return np.zeros_like(heat)
    return (heat - lo) / span


def get_attention_map(model, img_tensor):
    """Run *model* on *img_tensor* and return its normalized CLS heatmap.

    Args:
        model: callable returning an object with an ``attentions`` tuple
            when invoked with ``output_attentions=True``.
        img_tensor: model input batch; only the first item's map is returned.

    Returns:
        2-D numpy array in [0, 1] (all zeros for uniform attention).

    Raises:
        RuntimeError: if the model returned no attention maps.
    """
    with torch.no_grad():
        outputs = model(img_tensor, output_attentions=True)
    return _cls_attention_heatmap(outputs.attentions)


def overlay(image, heatmap):
    """Draw *heatmap* over *image* with matplotlib and return a PIL image.

    Args:
        image: PIL image drawn underneath; its ``size`` is the resolution
            the heatmap is resized to.
        heatmap: 2-D array of weights expected in [0, 1]. NaNs become 0 and
            values are clipped to [0, 1] so the uint8 cast cannot wrap.

    Returns:
        PIL image decoded from a PNG rendering of a 4x4 inch figure with
        axes hidden and the heatmap shown with the "jet" colormap at 50%
        opacity.
    """
    scaled = np.clip(np.nan_to_num(np.asarray(heatmap)), 0.0, 1.0)
    heat_img = Image.fromarray((255 * scaled).astype(np.uint8)).resize(image.size)
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(image)
        ax.imshow(heat_img, cmap="jet", alpha=0.5)
        ax.axis("off")
        buf = io.BytesIO()
        # Save this exact figure; plt.savefig would save pyplot's "current"
        # figure, which need not be `fig`.
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Release the figure even when rendering fails so pyplot's figure
        # registry does not keep growing.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def predict_image_pil(image):
    """Classify a PIL image as "Fake" or "Real" and build an attention overlay.

    Args:
        image: any PIL image; converted to RGB before preprocessing.

    Returns:
        (label, confidence, heatmap_img): label is "Fake" when class index 0
        wins and "Real" otherwise; confidence is the predicted class's
        softmax probability as a percentage rounded to 2 decimals;
        heatmap_img is the attention overlay rendered by ``overlay``.
    """
    image = image.convert("RGB")
    x = transform(image).unsqueeze(0).to(device)
    # A single forward pass supplies both the logits and the attention maps.
    with torch.no_grad():
        outputs = model(x, output_attentions=True)
    logits = outputs.logits
    pred = torch.argmax(logits, dim=1).item()
    label = "Fake" if pred == 0 else "Real"
    confidence = torch.softmax(logits, dim=1)[0][pred].item() * 100
    heatmap_img = overlay(image, _cls_attention_heatmap(outputs.attentions))
    return label, round(confidence, 2), heatmap_img