# File size: 2,281 bytes
# 960b635
import torch
from torchvision import transforms
from transformers import ViTForImageClassification, ViTConfig
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import io
import os
# Run on GPU when available; the model and all input tensors go to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ViT-Base/16 config with a 2-class head (real vs. fake) and attention outputs
# enabled so attention maps can be extracted for visualization.
config = ViTConfig.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=2,
    output_attentions=True
)

# ignore_mismatched_sizes=True lets the 2-class head replace the original
# 1000-class ImageNet head without a shape error.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    config=config,
    ignore_mismatched_sizes=True
)

# Load the fine-tuned real/fake weights. weights_only=True restricts
# unpickling to tensors and primitive containers, preventing arbitrary code
# execution from a tampered checkpoint; a plain state_dict needs nothing more.
model.load_state_dict(
    torch.load(
        "model/vit_real_fake_best.pth",
        map_location=device,
        weights_only=True,
    )
)
model.to(device)
model.eval()  # inference mode: disables dropout etc.

# Standard ImageNet preprocessing: ViT-Base/16 expects 224x224 inputs
# normalized with the ImageNet channel means / standard deviations.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    )
])
def get_attention_map(model, img_tensor):
    """Return the CLS-token attention over image patches as a (grid, grid) map.

    Uses the last transformer layer's attention averaged over heads, then
    min-max normalizes the result into [0, 1] for visualization.

    Args:
        model: callable accepting ``(img_tensor, output_attentions=True)`` and
            returning an object whose ``attentions`` sequence holds tensors
            shaped (batch, heads, seq, seq). Assumes batch size 1.
        img_tensor: preprocessed input batch passed straight to ``model``.

    Returns:
        numpy.ndarray of shape (grid, grid) with values in [0, 1], where
        grid = sqrt(number of patches). All zeros if attention is uniform.
    """
    with torch.no_grad():
        outputs = model(img_tensor, output_attentions=True)
    # Last layer, averaged over heads; [0] selects the sole batch element.
    attn = outputs.attentions[-1].mean(dim=1)[0]
    # Row 0 is the CLS token; columns 1: are its attention to the patches.
    cls_attn = attn[0, 1:]
    # Patches form a square grid (e.g. 196 patches -> 14x14 for ViT-B/16).
    grid = int(cls_attn.size(0) ** 0.5)
    cls_attn = cls_attn.reshape(grid, grid).cpu().numpy()
    # Min-max normalize, guarding against a zero range (uniform attention),
    # which would otherwise yield NaNs from 0/0.
    span = cls_attn.max() - cls_attn.min()
    if span > 0:
        cls_attn = (cls_attn - cls_attn.min()) / span
    else:
        cls_attn = np.zeros_like(cls_attn)
    return cls_attn
def overlay(image, heatmap):
    """Render `heatmap` on top of `image` with a jet colormap.

    Args:
        image: base PIL image to draw beneath the heatmap.
        heatmap: 2-D float array with values in [0, 1].

    Returns:
        PIL.Image.Image containing the rendered composite (tight-cropped PNG).
    """
    # Scale [0, 1] floats to 8-bit grayscale and upsample to the image size.
    scaled = Image.fromarray(np.uint8(255 * heatmap)).resize(image.size)

    figure, axes = plt.subplots(figsize=(4, 4))
    axes.imshow(image)
    # Half-transparent jet colormap so the underlying photo stays visible.
    axes.imshow(scaled, cmap="jet", alpha=0.5)
    axes.axis("off")

    # Serialize the figure to an in-memory PNG and reopen it as a PIL image.
    buffer = io.BytesIO()
    plt.savefig(buffer, format="png", bbox_inches="tight", pad_inches=0)
    plt.close(figure)
    buffer.seek(0)
    return Image.open(buffer)
def predict_image_pil(image):
    """Classify a PIL image as "Real" or "Fake" and build an attention overlay.

    Args:
        image: input PIL image (any mode; converted to RGB internally).

    Returns:
        Tuple ``(label, confidence, heatmap_img)`` where label is "Fake" or
        "Real", confidence is the predicted class's softmax probability as a
        percentage rounded to 2 decimals, and heatmap_img is a PIL image of
        the attention overlay.
    """
    rgb = image.convert("RGB")
    batch = transform(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch).logits

    pred_idx = torch.argmax(logits, dim=1).item()
    # Head index 0 maps to "Fake", index 1 to "Real".
    label = "Real" if pred_idx != 0 else "Fake"

    heatmap_img = overlay(rgb, get_attention_map(model, batch))
    confidence = torch.softmax(logits, dim=1)[0][pred_idx].item() * 100
    return label, round(confidence, 2), heatmap_img
# end of file