import io
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from transformers import ViTConfig, ViTForImageClassification

# -----------------------------
# Device
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# Model Setup (SAME AS CMD)
# -----------------------------
# Fine-tuned weights restored on top of the base checkpoint, if present.
MODEL_WEIGHTS_PATH = "model/vit_real_fake_best.pth"

config = ViTConfig.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=2,
    output_attentions=True,
)
# ignore_mismatched_sizes lets from_pretrained replace the checkpoint's
# classifier head with a newly initialised one sized for num_labels=2.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    config=config,
    ignore_mismatched_sizes=True,
)

if os.path.exists(MODEL_WEIGHTS_PATH):
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (newer PyTorch versions support weights_only=True).
    model.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location=device))
else:
    # Without the fine-tuned weights the 2-class head above is freshly
    # initialised; surface that instead of silently serving its outputs.
    warnings.warn(
        f"Fine-tuned weights not found at {MODEL_WEIGHTS_PATH!r}; "
        "the classification head is untrained.",
        RuntimeWarning,
        stacklevel=1,
    )

model.to(device)
model.eval()

# -----------------------------
# Image Preprocessing (IDENTICAL)
# -----------------------------
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225],
    ),
])


# -----------------------------
# Attention Heatmap (IDENTICAL)
# -----------------------------
def get_attention_map(model, img_tensor):
    """Return the last-layer CLS-to-patch attention as a normalised 2-D grid.

    Args:
        model: A ViT classifier whose forward accepts ``output_attentions``.
        img_tensor: Preprocessed image batch; only the first item is used.
            Presumably shaped (batch, 3, 224, 224) as produced by
            ``transform`` -- TODO confirm for other callers.

    Returns:
        A square numpy array (grid_size x grid_size) scaled to [0, 1].
        If the attention values are all equal, an all-zero map is returned
        instead of dividing by zero (which would produce NaNs).
    """
    with torch.no_grad():
        outputs = model(img_tensor, output_attentions=True)

    # Last layer, averaged over heads (dim 1), first image in the batch.
    attn = outputs.attentions[-1].mean(dim=1)[0]
    # Row 0 is the CLS token's attention; drop column 0 (CLS -> CLS).
    cls_attn = attn[0, 1:]
    grid_size = int(cls_attn.size(0) ** 0.5)
    cls_attn = cls_attn.reshape(grid_size, grid_size).cpu().numpy()

    attn_min = cls_attn.min()
    attn_range = cls_attn.max() - attn_min
    if attn_range > 0:
        cls_attn = (cls_attn - attn_min) / attn_range
    else:
        # Uniform attention carries no spatial signal; avoid 0/0 -> NaN.
        cls_attn = np.zeros_like(cls_attn)
    return cls_attn


def overlay_heatmap_on_image(image, heatmap):
    """Blend ``heatmap`` over ``image`` using a jet colormap.

    Args:
        image: PIL image drawn underneath.
        heatmap: 2-D array with values in [0, 1]; it is scaled to 0-255 and
            resized to ``image.size`` before blending.

    Returns:
        A PIL image decoded from an in-memory PNG rendering of the figure.
    """
    heatmap_img = Image.fromarray(np.uint8(255 * heatmap)).resize(image.size)
    heatmap_np = np.array(heatmap_img)

    fig, ax = plt.subplots(figsize=(4, 4))
    buf = io.BytesIO()
    try:
        ax.imshow(image)
        ax.imshow(heatmap_np, cmap="jet", alpha=0.5)
        ax.axis("off")
        # Save this figure explicitly, not pyplot's "current" figure.
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure so repeated calls do not leak memory.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
# -----------------------------
# Prediction Function (SOURCE OF TRUTH)
# -----------------------------
def predict_image_pil(image):
    """Classify a PIL image and build an attention overlay for it.

    Args:
        image: Any PIL image; it is converted to RGB before preprocessing.

    Returns:
        A ``(label, confidence, heatmap_img)`` tuple:
        - ``label`` is ``"Fake"`` for class index 0 and ``"Real"`` otherwise;
        - ``confidence`` is the softmax probability of the predicted class,
          as a percentage rounded to two decimals;
        - ``heatmap_img`` is the attention heatmap drawn over the RGB image.
    """
    rgb_image = image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        scores = model(batch).logits

    class_idx = scores.argmax(dim=1).item()
    probabilities = scores.softmax(dim=1)
    confidence = probabilities[0, class_idx].item() * 100
    label = "Real" if class_idx else "Fake"

    # The helper runs its own forward pass to collect attention weights.
    attention_grid = get_attention_map(model, batch)
    overlay = overlay_heatmap_on_image(rgb_image, attention_grid)

    return label, round(confidence, 2), overlay