Spaces:
Sleeping
Sleeping
import io
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from transformers import ViTConfig, ViTForImageClassification
# -----------------------------
# Device
# -----------------------------
# Run on CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# -----------------------------
# Model Setup (SAME AS CMD)
# -----------------------------
# Fine-tuned weights for the 2-class (Fake/Real) task. Without this file the
# classification head built below is freshly initialised, so its outputs are
# not meaningful.
MODEL_PATH = "model/vit_real_fake_best.pth"

# Same backbone as training, with a 2-way head. output_attentions=True makes
# every forward pass also return the per-layer attention weights used for the
# heatmap.
config = ViTConfig.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=2,
    output_attentions=True,
)
# ignore_mismatched_sizes=True: the pretrained checkpoint's classifier has a
# different output size than num_labels=2, so those weights are re-initialised
# instead of raising a size-mismatch error.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    config=config,
    ignore_mismatched_sizes=True,
)

if os.path.exists(MODEL_PATH):
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
else:
    # Previously this fell through silently and served an untrained head;
    # keep running (best effort) but make the degraded state visible.
    warnings.warn(
        f"Checkpoint {MODEL_PATH!r} not found; using an untrained "
        "classification head, so predictions will be unreliable.",
        RuntimeWarning,
    )

model.to(device)
model.eval()  # inference mode: disables dropout
# -----------------------------
# Image Preprocessing (IDENTICAL)
# -----------------------------
# Per-channel RGB mean / std passed to Normalize (the widely used ImageNet values).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Resize to exactly 224x224 (aspect ratio is not preserved), convert the PIL
# image to a float C x H x W tensor in [0, 1], then standardise each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
# -----------------------------
# Attention Heatmap (IDENTICAL)
# -----------------------------
def get_attention_map(model, img_tensor):
    """Return the last layer's CLS-token attention over patches as a 2-D map.

    Runs ``model`` on ``img_tensor`` with attention outputs enabled, averages
    the final layer's attention over heads, takes the first batch item's
    CLS-token row (dropping the CLS->CLS entry) and reshapes it into a square
    grid.

    Args:
        model: A ViT-style model whose outputs expose ``attentions`` (one
            tensor per layer).
        img_tensor: Batched input tensor; only the first item is used.
            Presumably the output of ``transform`` with a batch dim — confirm.

    Returns:
        numpy array of shape (g, g), where g = sqrt(number of patch tokens),
        min-max scaled to [0, 1]. If every value is equal, an all-zeros map
        is returned instead of the NaNs a 0/0 division would produce.
    """
    with torch.no_grad():
        outputs = model(img_tensor, output_attentions=True)
    # Last layer, averaged over heads (dim 1), for the first image in the batch.
    attn = outputs.attentions[-1].mean(dim=1)[0]
    # Row 0 = what the CLS token attends to; column 0 is CLS itself, so skip it.
    cls_attn = attn[0, 1:]
    grid_size = int(cls_attn.size(0) ** 0.5)
    cls_attn = cls_attn.reshape(grid_size, grid_size).cpu().numpy()
    lo, hi = cls_attn.min(), cls_attn.max()
    if hi > lo:
        cls_attn = (cls_attn - lo) / (hi - lo)
    else:
        # Constant map: min-max scaling would divide by zero and yield NaNs.
        cls_attn = np.zeros_like(cls_attn)
    return cls_attn
def overlay_heatmap_on_image(image, heatmap):
    """Draw ``heatmap`` semi-transparently over ``image`` and return the result.

    Args:
        image: PIL image used as the background; its ``size`` is the size the
            heatmap is resized to.
        heatmap: 2-D array, presumably scaled to [0, 1] (as produced by
            get_attention_map); it is multiplied by 255 and cast to uint8.

    Returns:
        PIL image decoded from a PNG rendering of a 4x4-inch matplotlib figure
        (axes hidden) showing ``image`` with the heatmap on top using the
        "jet" colormap at 50% opacity.
    """
    heatmap_u8 = np.uint8(255 * heatmap)
    heatmap_np = np.array(Image.fromarray(heatmap_u8).resize(image.size))

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(image)
        ax.imshow(heatmap_np, cmap="jet", alpha=0.5)
        ax.axis("off")
        buf = io.BytesIO()
        # Save this figure explicitly: plt.savefig() targets pyplot's global
        # "current" figure, which another concurrent request may have replaced.
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Release the figure even if drawing/saving fails so figures do not
        # accumulate in pyplot's registry.
        plt.close(fig)

    buf.seek(0)
    result = Image.open(buf)
    # Image.open is lazy; decode now so the pixels are read immediately rather
    # than on first use.
    result.load()
    return result
# -----------------------------
# Prediction Function (SOURCE OF TRUTH)
# -----------------------------
def predict_image_pil(image):
    """Classify a PIL image as "Fake" or "Real" and build an attention overlay.

    Args:
        image: PIL image in any mode; it is converted to RGB before use.

    Returns:
        ``(label, confidence, heatmap_img)`` where ``label`` is "Fake" for class
        index 0 and "Real" otherwise, ``confidence`` is the softmax probability
        of the predicted class as a percentage rounded to 2 decimals, and
        ``heatmap_img`` is the overlay produced by overlay_heatmap_on_image.

    Raises:
        ValueError: if ``image`` is None.
    """
    if image is None:
        raise ValueError("No image provided.")
    image = image.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)

    # One forward pass yields both logits and attention weights. Calling
    # get_attention_map() here would run the whole ViT a second time.
    with torch.no_grad():
        outputs = model(input_tensor, output_attentions=True)

    logits = outputs.logits
    pred = torch.argmax(logits, dim=1).item()
    label = "Fake" if pred == 0 else "Real"  # class index 0 = Fake, 1 = Real
    confidence = torch.softmax(logits, dim=1)[0][pred].item() * 100

    attn_map = _cls_attention_grid(outputs.attentions)
    heatmap_img = overlay_heatmap_on_image(image, attn_map)
    return label, round(confidence, 2), heatmap_img


def _cls_attention_grid(attentions):
    """Turn already-computed per-layer attentions into a [0, 1] CLS-to-patch grid.

    Performs the same extraction as get_attention_map (last layer, mean over
    heads, first batch item, CLS row without the CLS->CLS entry, reshaped to
    a square grid, min-max scaled) without another forward pass. A constant
    map yields zeros instead of the NaNs a 0/0 division would produce.
    """
    attn = attentions[-1].mean(dim=1)[0]
    cls_attn = attn[0, 1:]
    grid_size = int(cls_attn.size(0) ** 0.5)
    cls_attn = cls_attn.reshape(grid_size, grid_size).cpu().numpy()
    lo, hi = cls_attn.min(), cls_attn.max()
    if hi > lo:
        return (cls_attn - lo) / (hi - lo)
    return np.zeros_like(cls_attn)