File size: 2,942 Bytes
a7c6634
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch
from torchvision import transforms
from transformers import ViTForImageClassification, ViTConfig
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import io
import os

# -----------------------------
# Device
# -----------------------------
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# -----------------------------
# Model Setup (SAME AS CMD)
# -----------------------------
# ViT backbone configured for binary real/fake classification; attentions
# are exposed so the attention heatmap can be extracted at inference time.
config = ViTConfig.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=2,
    output_attentions=True
)

# ignore_mismatched_sizes: the pretrained 1000-class head is replaced with
# a fresh 2-class head, so those weights intentionally do not match.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    config=config,
    ignore_mismatched_sizes=True
)

# Load fine-tuned weights if present. weights_only=True (torch >= 1.13)
# restricts unpickling to tensors/primitives — the checkpoint is consumed
# as a plain state_dict, so nothing else should be in it.
if os.path.exists("model/vit_real_fake_best.pth"):
    model.load_state_dict(
        torch.load(
            "model/vit_real_fake_best.pth",
            map_location=device,
            weights_only=True,
        )
    )

model.to(device)
model.eval()

# -----------------------------
# Image Preprocessing (IDENTICAL)
# -----------------------------
# 224x224 resize + ImageNet mean/std normalization, matching the ViT
# checkpoint's expected input statistics.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    ),
]
transform = transforms.Compose(_preprocess_steps)

# -----------------------------
# Attention Heatmap (IDENTICAL)
# -----------------------------
def get_attention_map(model, img_tensor):
    """Return the last-layer CLS attention as a [0, 1] heatmap grid.

    Args:
        model: callable returning an object with an ``.attentions`` list
            when invoked with ``output_attentions=True`` (e.g. a HF ViT).
        img_tensor: preprocessed input batch; only batch element 0 is used
            (assumed batch size 1 -- TODO confirm callers never batch).

    Returns:
        np.ndarray of shape (grid, grid): the CLS token's attention over
        the image patches, min-max scaled into [0, 1].
    """
    with torch.no_grad():
        outputs = model(img_tensor, output_attentions=True)
        # Last layer, averaged over heads, first batch element.
        attn = outputs.attentions[-1].mean(dim=1)[0]
        # Row 0 = CLS token; drop column 0 (its self-attention) to keep
        # only attention over the patch tokens.
        cls_attn = attn[0, 1:]

        grid_size = int(cls_attn.size(0) ** 0.5)
        cls_attn = cls_attn.reshape(grid_size, grid_size).cpu().numpy()
        # Epsilon guards a constant map (max == min), which previously
        # produced a 0/0 division and an all-NaN heatmap.
        span = cls_attn.max() - cls_attn.min()
        cls_attn = (cls_attn - cls_attn.min()) / (span + 1e-8)

        return cls_attn

def overlay_heatmap_on_image(image, heatmap):
    """Blend an attention heatmap over a PIL image and return the overlay.

    Args:
        image: PIL.Image drawn underneath the heatmap.
        heatmap: 2-D array of floats in [0, 1].

    Returns:
        PIL.Image opened from an in-memory PNG of the rendered overlay.
    """
    # Scale to 8-bit and stretch the heatmap to the image's dimensions.
    scaled = np.uint8(255 * heatmap)
    resized = np.array(Image.fromarray(scaled).resize(image.size))

    # Render: base image with a semi-transparent jet colormap on top.
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(image)
    ax.imshow(resized, cmap="jet", alpha=0.5)
    ax.axis("off")

    # Write to an in-memory PNG buffer rather than the filesystem.
    buffer = io.BytesIO()
    plt.savefig(buffer, format="png", bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    buffer.seek(0)

    return Image.open(buffer)

# -----------------------------
# Prediction Function (SOURCE OF TRUTH)
# -----------------------------
def predict_image_pil(image):
    """Classify a PIL image as Real/Fake and build an attention overlay.

    Args:
        image: PIL.Image in any mode; converted to RGB internally.

    Returns:
        (label, confidence, heatmap_img): label is "Fake" (class 0) or
        "Real" (class 1); confidence is the softmax probability of the
        predicted class as a percentage, rounded to 2 decimals;
        heatmap_img is a PIL.Image of the attention overlay.
    """
    image = image.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Single forward pass with attentions enabled: previously the model was
    # run twice per prediction (once for logits, once for the heatmap).
    with torch.no_grad():
        outputs = model(input_tensor, output_attentions=True)
        logits = outputs.logits
        pred = torch.argmax(logits, dim=1).item()

        # CLS-token attention over patch tokens, last layer, head-averaged
        # (same extraction as get_attention_map).
        attn = outputs.attentions[-1].mean(dim=1)[0]
        cls_attn = attn[0, 1:]
        grid_size = int(cls_attn.size(0) ** 0.5)
        attn_map = cls_attn.reshape(grid_size, grid_size).cpu().numpy()
        # Min-max normalize; epsilon avoids 0/0 on a constant map.
        span = attn_map.max() - attn_map.min()
        attn_map = (attn_map - attn_map.min()) / (span + 1e-8)

    label = "Fake" if pred == 0 else "Real"
    heatmap_img = overlay_heatmap_on_image(image, attn_map)
    confidence = torch.softmax(logits, dim=1)[0][pred].item() * 100

    return label, round(confidence, 2), heatmap_img