| import torch | |
| from torchvision import transforms | |
| from transformers import ViTForImageClassification, ViTConfig | |
| from PIL import Image | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import io | |
| import os | |
# Run on GPU when available; model and inputs are both moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Base ViT config narrowed to a 2-class head (real vs. fake classification)
# with attention maps exposed so heatmaps can be built from a forward pass.
config = ViTConfig.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=2,
    output_attentions=True
)

# ignore_mismatched_sizes: the pretrained checkpoint has a 1000-class head,
# which is discarded in favor of the new 2-class head from `config`.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    config=config,
    ignore_mismatched_sizes=True
)

# Load the fine-tuned real/fake weights; map_location keeps CPU-only
# machines working when the checkpoint was saved from a GPU run.
model.load_state_dict(
    torch.load("model/vit_real_fake_best.pth", map_location=device)
)
model.to(device)
model.eval()  # inference mode: disables dropout / batch-norm updates

# Standard ImageNet preprocessing: 224x224 input with ImageNet mean/std
# normalization, matching what the ViT backbone was pretrained with.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    )
])
def get_attention_map(model, img_tensor):
    """Build a normalized CLS-token attention heatmap from the last layer.

    Args:
        model: ViT-style model; calling it with ``output_attentions=True``
            must return an object whose ``.attentions`` is a sequence of
            ``(batch, heads, seq, seq)`` tensors.
        img_tensor: preprocessed input batch; assumed batch size 1 — only
            batch index 0 is read.

    Returns:
        ``np.ndarray`` of shape ``(grid, grid)`` with values in ``[0, 1]``,
        where ``grid**2`` is the number of patch tokens.
    """
    with torch.no_grad():
        outputs = model(img_tensor, output_attentions=True)
    # Last layer's attention, averaged over heads; take the first batch item.
    attn = outputs.attentions[-1].mean(dim=1)[0]
    # Row 0 is the CLS token's attention; drop column 0 (its self-attention)
    # so only the patch tokens remain.
    cls_attn = attn[0, 1:]
    # Patch tokens form a square grid for standard ViT inputs.
    grid = int(cls_attn.size(0) ** 0.5)
    cls_attn = cls_attn.reshape(grid, grid).cpu().numpy()
    # Min-max normalize to [0, 1]. Guard the degenerate uniform-attention
    # case: the original unconditional division produced NaNs (0 / 0).
    span = cls_attn.max() - cls_attn.min()
    if span == 0:
        return np.zeros_like(cls_attn)
    return (cls_attn - cls_attn.min()) / span
def overlay(image, heatmap):
    """Blend an attention heatmap over *image* and return the composite.

    Args:
        image: PIL image the heatmap is drawn over.
        heatmap: 2-D float array with values in [0, 1].

    Returns:
        PIL image rendered from an in-memory PNG of the matplotlib figure.
    """
    # Scale [0, 1] floats to 8-bit grayscale and upsample to the image size.
    scaled = np.uint8(255 * heatmap)
    resized = Image.fromarray(scaled).resize(image.size)

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(image)
    # Semi-transparent colormap layer on top of the original image.
    ax.imshow(resized, cmap="jet", alpha=0.5)
    ax.axis("off")

    # Render to an in-memory PNG rather than the screen or disk.
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    buffer.seek(0)
    return Image.open(buffer)
def predict_image_pil(image):
    """Classify a PIL image as real or fake and build an attention overlay.

    Args:
        image: input PIL image (any mode; converted to RGB internally).

    Returns:
        Tuple of (label, confidence, heatmap_img): ``"Real"``/``"Fake"``
        string, softmax confidence as a percentage rounded to two decimals,
        and the PIL overlay image.
    """
    rgb = image.convert("RGB")
    batch = transform(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch).logits

    pred_idx = torch.argmax(logits, dim=1).item()
    # NOTE(review): index 0 is assumed to be the "Fake" class — confirm
    # against the label mapping used during fine-tuning.
    label = "Real" if pred_idx != 0 else "Fake"

    heatmap_img = overlay(rgb, get_attention_map(model, batch))
    confidence = torch.softmax(logits, dim=1)[0][pred_idx].item() * 100
    return label, round(confidence, 2), heatmap_img