File size: 2,281 Bytes
960b635
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
from torchvision import transforms
from transformers import ViTForImageClassification, ViTConfig
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import io
import os

# Use the GPU when one is available; inputs passed to `model` must live on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ViT-B/16 architecture from the base checkpoint, with a 2-class head.
# output_attentions=True makes forward passes also return per-layer attention
# weights, which get_attention_map relies on.
config = ViTConfig.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=2,
    output_attentions=True
)

# The base checkpoint's classification head has a different number of classes
# than num_labels=2, hence ignore_mismatched_sizes. With load_state_dict's
# default strict=True, every parameter is then overwritten by the fine-tuned
# checkpoint below.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    config=config,
    ignore_mismatched_sizes=True
)

# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs, so a tampered .pth cannot execute code on load.
# NOTE: the path is resolved relative to the current working directory.
model.load_state_dict(
    torch.load(
        "model/vit_real_fake_best.pth",
        map_location=device,
        weights_only=True,
    )
)

model.to(device)
model.eval()  # inference mode: disables dropout


# Preprocessing applied to each input image before inference:
# resize to 224x224 (the resolution named by the base checkpoint), convert the
# PIL image to a float tensor in [0, 1], then standardise each RGB channel.
# NOTE(review): these are the ImageNet mean/std values. The upstream
# google/vit-base-patch16-224 image processor appears to use mean=std=0.5, so
# whatever the checkpoint was fine-tuned with must be used here — confirm
# against the training script.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


def get_attention_map(model, img_tensor):
    """Return the last layer's CLS-to-patch attention as a normalised 2-D grid.

    Runs ``model`` once on ``img_tensor`` with ``output_attentions=True``,
    takes the final entry of ``outputs.attentions``, averages it over dim 1,
    and keeps row 0 (the first token) against every other token.

    Args:
        model: Callable returning an object with an ``attentions`` sequence.
            Assumes each entry is shaped (batch, heads, tokens, tokens) as in
            Hugging Face ViT — confirm for other models.
        img_tensor: Input batch; only the first item (index 0) is used.

    Returns:
        A square numpy array of side ``int(sqrt(tokens - 1))`` min-max scaled
        to [0, 1]. If every value is equal, an all-zero array is returned
        instead of dividing by zero (which would produce NaNs).
    """
    with torch.no_grad():
        outputs = model(img_tensor, output_attentions=True)
        # Average over heads (dim 1), then select the first image in the batch.
        attn = outputs.attentions[-1].mean(dim=1)[0]
        # Row 0 is the CLS token's attention; drop column 0 (CLS -> itself).
        cls_attn = attn[0, 1:]

        grid = int(cls_attn.size(0) ** 0.5)
        cls_attn = cls_attn.reshape(grid, grid).cpu().numpy()

    lo = cls_attn.min()
    span = cls_attn.max() - lo
    if span > 0:
        return (cls_attn - lo) / span
    # Constant map: 0/0 would yield NaNs, which become arbitrary bytes when the
    # caller casts to uint8.
    return np.zeros_like(cls_attn)


def overlay(image, heatmap):
    """Draw ``heatmap`` semi-transparently over ``image`` and return it as a PNG image.

    Args:
        image: PIL image used as the background. Its ``size`` is the size the
            heatmap is resized to.
        heatmap: 2-D array expected in [0, 1] (as returned by
            ``get_attention_map``). It is scaled by 255 and cast to uint8.

    Returns:
        A PIL image decoded from an in-memory PNG rendering of the figure.
    """
    heatmap = np.uint8(255 * heatmap)
    heatmap = Image.fromarray(heatmap).resize(image.size)

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(image)
        ax.imshow(heatmap, cmap="jet", alpha=0.5)
        ax.axis("off")

        buf = io.BytesIO()
        # Save this figure explicitly instead of pyplot's global "current"
        # figure, which concurrent callers (e.g. parallel web requests) could
        # have switched.
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure; pyplot keeps open figures alive otherwise.
        plt.close(fig)
    buf.seek(0)

    return Image.open(buf)


def predict_image_pil(image):
    """Classify a PIL image as "Real" or "Fake" and build an attention overlay.

    The image is converted to RGB, preprocessed with the module-level
    ``transform``, and run through ``model`` on ``device``. Class index 0 maps
    to "Fake" and any other index to "Real".

    Returns:
        ``(label, confidence, heatmap_img)``. ``confidence`` is the softmax
        probability of the predicted class in percent, rounded to 2 decimals.
        ``heatmap_img`` is the attention overlay produced by ``overlay``.
    """
    rgb = image.convert("RGB")
    batch = transform(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch).logits
        class_idx = torch.argmax(logits, dim=1).item()

    probs = torch.softmax(logits, dim=1)
    confidence = probs[0][class_idx].item() * 100

    label = "Real" if class_idx != 0 else "Fake"
    heatmap_img = overlay(rgb, get_attention_map(model, batch))

    return label, round(confidence, 2), heatmap_img