Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
| import cloudinary.uploader | |
| import io | |
| import matplotlib.pyplot as plt | |
| import uuid | |
# Forward pass returning per-class softmax probabilities (see get_top3 for ranking).
def predict_image(model, image_tensor, class_names):
    """Run ``model`` on ``image_tensor`` and return softmax class probabilities.

    The forward pass runs with the model in ``eval()`` mode and under
    ``torch.no_grad()``. Layers such as dropout / batch-norm therefore behave
    deterministically, and no autograd graph is built. The model's previous
    train/eval mode is restored afterwards, so callers that were training are
    not silently switched to eval.

    Args:
        model: a ``torch.nn.Module``. Presumably it returns logits shaped
            [batch, num_classes]; TODO confirm against the model definition.
        image_tensor: the preprocessed input, passed unchanged to ``model``.
        class_names: not used by this function; kept so existing call sites
            keep working. Labels are attached separately by ``get_top3``.

    Returns:
        A tensor of probabilities: softmax of the model output over dim 1.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(image_tensor)
    finally:
        # Restore whatever mode the caller had, even if the forward pass raised.
        model.train(was_training)
    return F.softmax(logits, dim=1)
def _saliency_gradient_map(model, image_tensor, target_index):
    """Return (first input image [C, H, W] on CPU, saliency map [H, W] on CPU).

    The saliency is max over channels of |d score / d input| for the first image.
    """
    model.eval()
    # Run on the model's own device so a GPU-resident model still works; the
    # original code always forced CPU, which broke in that case.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    # Move BEFORE marking requires_grad so `inputs` is the autograd leaf we differentiate.
    inputs = image_tensor.detach().clone().to(device).requires_grad_(True)
    score = model(inputs)[0, target_index]
    # autograd.grad differentiates w.r.t. `inputs` only. Unlike score.backward(),
    # it does not accumulate stale gradients into the model parameters' .grad.
    (grad,) = torch.autograd.grad(score, inputs)
    # Index the batch dim explicitly instead of squeeze(), which would also drop
    # any other size-1 dimension.
    saliency = grad[0].abs().amax(dim=0).cpu()
    return inputs.detach()[0].cpu(), saliency


def _upload_figure_png(np_img, title):
    """Render ``np_img`` with matplotlib (axes hidden), upload it as a PNG, and return its secure_url.

    The upload goes to the Cloudinary folder "saliency".
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(np_img)
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        buf.seek(0)
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)
    result = cloudinary.uploader.upload(buf, folder="saliency", public_id=f"{title}_{uuid.uuid4().hex}")
    return result["secure_url"]


# Generate 3 saliency images and upload to Cloudinary
def generate_and_upload_saliency_images(model, image_tensor, target_index):
    """Compute a vanilla-gradient saliency map for one class and upload three images.

    Only the first image of the batch is used. Calls ``model.eval()`` and leaves
    the model in eval mode.

    Args:
        model: a ``torch.nn.Module`` returning class scores indexed as
            ``outputs[0, target_index]``.
        image_tensor: input batch. Presumably [1, 3, H, W], normalized with
            mean=0.5 / std=0.5 per channel; TODO confirm against the
            preprocessing transform.
        target_index: index of the class whose score is differentiated.

    Returns:
        dict with keys "original", "saliency" and "overlay", each mapping to the
        Cloudinary ``secure_url`` of the uploaded PNG.
    """
    input_img, saliency = _saliency_gradient_map(model, image_tensor, target_index)

    # (x + 1) / 2 inverts Normalize(mean=0.5, std=0.5), i.e. x = 2*y - 1.
    # Equivalent to the original Normalize(mean=-1, std=2) for every channel count.
    original_img = ((input_img + 1) / 2).permute(1, 2, 0).clamp(0, 1).numpy()

    # Min-max scale to [0, 1]; epsilon guards against a constant (all-zero) map.
    saliency_min = saliency.min()
    saliency_max = saliency.max()
    saliency_np = ((saliency - saliency_min) / (saliency_max - saliency_min + 1e-7)).numpy()

    saliency_color = plt.cm.hot(saliency_np)[:, :, :3]  # drop alpha channel
    overlay = np.clip(0.6 * original_img + 0.4 * saliency_color, 0, 1)

    return {
        "original": _upload_figure_png(original_img, "original"),
        "saliency": _upload_figure_png(saliency_np, "saliency"),
        "overlay": _upload_figure_png(overlay, "overlay"),
    }
def get_top3(probs, class_names, k=3):
    """Return the ``k`` (default 3) most probable classes, highest first.

    Args:
        probs: class probabilities, either [1, num_classes] or [num_classes].
            Only a batch of one is supported; the leading dim is squeezed away.
        class_names: sequence mapping class index -> label.
        k: number of entries to return. It is clamped to the number of classes,
            so models with fewer than ``k`` classes do not raise.

    Returns:
        list of dicts with keys "label" (str), "confidence" (float) and "index" (int).
    """
    scores = probs.squeeze(0)  # remove batch dim
    k = min(k, scores.shape[-1])
    top_probs, top_indices = torch.topk(scores, k)
    return [
        {
            "label": class_names[idx],
            "confidence": float(p),
            "index": idx,
        }
        for p, idx in zip(top_probs.tolist(), top_indices.tolist())
    ]