# NOTE: the lines below are page-scrape residue kept as comments; they are
# not part of the program. (A gutter of line numbers 1-102 was dropped.)
# Spaces: Sleeping
# File size: 2,942 Bytes
# a7c6634
import io
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification, ViTConfig
# -----------------------------
# Device
# -----------------------------
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
# -----------------------------
# Model Setup (SAME AS CMD)
# -----------------------------
# Identifier of the pretrained checkpoint that supplies the architecture and
# the initial weights.
_BASE_CHECKPOINT = "google/vit-base-patch16-224"
# Optional fine-tuned state dict, resolved relative to the working directory.
_FINETUNED_WEIGHTS = "model/vit_real_fake_best.pth"

# Two output labels; attentions are requested so the heatmap code can read them.
config = ViTConfig.from_pretrained(
    _BASE_CHECKPOINT,
    num_labels=2,
    output_attentions=True,
)
# ignore_mismatched_sizes tolerates checkpoint weights whose shapes differ from
# this configuration (presumably the classifier head, given num_labels=2 --
# confirm against the base checkpoint).
model = ViTForImageClassification.from_pretrained(
    _BASE_CHECKPOINT,
    config=config,
    ignore_mismatched_sizes=True,
)

if os.path.exists(_FINETUNED_WEIGHTS):
    # NOTE(security): torch.load unpickles the file, so only trusted
    # checkpoints should ever be placed at this path.
    model.load_state_dict(torch.load(_FINETUNED_WEIGHTS, map_location=device))
else:
    # Loading stays best-effort, but the fallback must not be silent: without
    # this file the model runs with whatever from_pretrained produced.
    warnings.warn(
        f"Fine-tuned weights not found at {_FINETUNED_WEIGHTS!r}; "
        "running without them, so predictions may be unreliable.",
        RuntimeWarning,
    )

model.to(device)
model.eval()  # inference mode: disables dropout and similar training-only layers
# -----------------------------
# Image Preprocessing (IDENTICAL)
# -----------------------------
# Preprocessing applied to every input image, in order: resize to a fixed
# 224x224, convert to a tensor, then normalize each of the three channels.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# -----------------------------
# Attention Heatmap (IDENTICAL)
# -----------------------------
def get_attention_map(model, img_tensor):
    """Return the last layer's first-token attention as a normalized 2-D map.

    Args:
        model: Callable invoked as ``model(img_tensor, output_attentions=True)``
            whose result exposes ``attentions``, a sequence of tensors.
        img_tensor: Input batch forwarded to the model. Only the first sample
            of the batch contributes to the returned map.

    Returns:
        A square ``numpy`` array min-max scaled to ``[0, 1]``. When every
        value is identical the map is all zeros instead of NaN.

    Raises:
        RuntimeError: If the number of non-first tokens is not a perfect
            square, because the reshape to a square grid then fails.
    """
    with torch.no_grad():
        outputs = model(img_tensor, output_attentions=True)

    # Last entry of ``attentions``, averaged over dim 1, first batch sample.
    # Assumes a (batch, heads, tokens, tokens) layout -- TODO confirm.
    attn = outputs.attentions[-1].mean(dim=1)[0]
    # Row 0 minus its first column: presumably the CLS token's attention to
    # each image patch (token 0 assumed to be CLS).
    cls_attn = attn[0, 1:]
    grid_size = int(cls_attn.size(0) ** 0.5)
    cls_attn = cls_attn.reshape(grid_size, grid_size).cpu().numpy()

    lowest = cls_attn.min()
    span = cls_attn.max() - lowest
    if span == 0:
        # A constant map would divide 0 by 0 and turn every cell into NaN.
        return np.zeros_like(cls_attn)
    return (cls_attn - lowest) / span
def overlay_heatmap_on_image(image, heatmap):
    """Render ``heatmap`` semi-transparently over ``image``.

    Args:
        image: PIL image drawn as the background; its ``size`` is the
            resolution the heatmap is resized to.
        heatmap: 2-D array of floats expected in ``[0, 1]``. NaN is treated
            as 0 and out-of-range values are clamped.

    Returns:
        A PIL image opened from the PNG that matplotlib rendered in memory.
    """
    # Sanitize before the uint8 cast: NaN has no defined uint8 value and
    # values outside [0, 1] would wrap around instead of saturating.
    scaled = np.clip(np.nan_to_num(heatmap), 0.0, 1.0)
    heatmap_img = Image.fromarray(np.uint8(255 * scaled)).resize(image.size)
    heatmap_np = np.array(heatmap_img)

    buf = io.BytesIO()
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(image)
        ax.imshow(heatmap_np, cmap="jet", alpha=0.5)
        ax.axis("off")
        # Save through the figure itself rather than pyplot's global "current
        # figure", which another caller could have switched in the meantime.
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure, even when rendering fails; pyplot keeps
        # figures alive until they are explicitly closed.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
# -----------------------------
# Prediction Function (SOURCE OF TRUTH)
# -----------------------------
def predict_image_pil(image):
    """Classify a PIL image and build its attention overlay.

    Args:
        image: PIL image in any mode; it is converted to RGB before use.

    Returns:
        A ``(label, confidence, heatmap_img)`` tuple: the label is ``"Fake"``
        or ``"Real"``, the confidence is the softmax probability of the
        predicted class as a percentage rounded to two decimals, and the
        heatmap image is the attention overlay drawn on the RGB input.
    """
    rgb_image = image.convert("RGB")
    # Add the batch dimension and move the tensor to the model's device.
    batch = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch).logits

    pred = torch.argmax(logits, dim=1).item()
    probabilities = torch.softmax(logits, dim=1)
    confidence = probabilities[0][pred].item() * 100

    # Class index 0 is reported as "Fake"; any other index as "Real".
    if pred == 0:
        label = "Fake"
    else:
        label = "Real"

    attn_map = get_attention_map(model, batch)
    heatmap_img = overlay_heatmap_on_image(rgb_image, attn_map)
    return label, round(confidence, 2), heatmap_img
|