"""ViT-based real/fake image classifier with attention-rollout heatmaps.

Loads a fine-tuned google/vit-base-patch16-224 checkpoint (2 labels) and
exposes `predict_image_pil`, which returns a label, a confidence score,
and a heatmap overlay of the CLS token's last-layer attention.
"""
import io
import os

import torch
from torchvision import transforms
from transformers import ViTForImageClassification, ViTConfig
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# output_attentions=True so every forward pass also returns attention maps;
# num_labels=2 replaces the 1000-class ImageNet head (hence
# ignore_mismatched_sizes when loading the pretrained weights).
config = ViTConfig.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=2,
    output_attentions=True
)
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    config=config,
    ignore_mismatched_sizes=True
)
model.load_state_dict(
    torch.load("model/vit_real_fake_best.pth", map_location=device)
)
model.to(device)
model.eval()

# Standard ImageNet preprocessing expected by the ViT backbone.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    )
])


def _attention_from_outputs(outputs):
    """Extract a normalized 2-D CLS-attention map from model outputs.

    Uses the last attention layer, averaged over heads, and takes the
    CLS token's attention to every patch token (index 0 attends, 1: are
    patches). Returns a (grid, grid) float array scaled to [0, 1].
    """
    # outputs.attentions[-1]: (batch, heads, seq, seq); mean over heads,
    # take the first (only) batch element.
    attn = outputs.attentions[-1].mean(dim=1)[0]
    cls_attn = attn[0, 1:]  # CLS -> patch attentions, length grid*grid
    grid = int(cls_attn.size(0) ** 0.5)
    cls_attn = cls_attn.reshape(grid, grid).cpu().numpy()
    # Min-max normalize; guard against a uniform map (max == min),
    # which would otherwise produce NaNs via 0/0.
    rng = cls_attn.max() - cls_attn.min()
    if rng > 0:
        cls_attn = (cls_attn - cls_attn.min()) / rng
    else:
        cls_attn = np.zeros_like(cls_attn)
    return cls_attn


def get_attention_map(model, img_tensor):
    """Run a forward pass and return the normalized CLS attention map.

    Kept for backward compatibility; `predict_image_pil` reuses its own
    forward pass via `_attention_from_outputs` instead of calling this.
    """
    with torch.no_grad():
        outputs = model(img_tensor, output_attentions=True)
    return _attention_from_outputs(outputs)


def overlay(image, heatmap):
    """Blend a [0, 1] heatmap over a PIL image and return a PIL image.

    The heatmap is resized to the image dimensions and rendered with the
    'jet' colormap at 50% alpha via matplotlib, then read back from an
    in-memory PNG buffer.
    """
    heatmap = np.uint8(255 * heatmap)
    heatmap = Image.fromarray(heatmap).resize(image.size)
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(image)
    ax.imshow(heatmap, cmap="jet", alpha=0.5)
    ax.axis("off")
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    plt.close(fig)  # free the figure to avoid leaking across calls
    buf.seek(0)
    return Image.open(buf)


def predict_image_pil(image):
    """Classify a PIL image as Real/Fake with confidence and heatmap.

    Returns (label, confidence_percent, heatmap_image) where label is
    "Fake" for class 0 and "Real" for class 1, confidence is a float
    rounded to 2 decimals, and heatmap_image is a PIL image overlay.
    """
    image = image.convert("RGB")
    x = transform(image).unsqueeze(0).to(device)
    # Single forward pass: the config requests attentions, so we get
    # both logits and attention maps here (previously the model ran
    # twice — once for logits and once inside get_attention_map).
    with torch.no_grad():
        outputs = model(x)
    logits = outputs.logits
    pred = torch.argmax(logits, dim=1).item()
    label = "Fake" if pred == 0 else "Real"
    heat = _attention_from_outputs(outputs)
    heatmap_img = overlay(image, heat)
    confidence = torch.softmax(logits, dim=1)[0][pred].item() * 100
    return label, round(confidence, 2), heatmap_img