import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import cloudinary.uploader
import io
import matplotlib.pyplot as plt
import uuid


def predict_image(model, image_tensor, class_names):
    """Run a gradient-free forward pass and return softmax probabilities.

    Args:
        model: Callable model; its output is softmaxed along dim 1, so it is
            assumed to return logits shaped ``[batch, num_classes]``.
        image_tensor: Input batch passed straight to ``model``.
        class_names: Currently unused; kept for call-site compatibility.

    Returns:
        Tensor of class probabilities (softmax over dim 1). Top-k selection
        is left to the caller.
    """
    with torch.no_grad():
        outputs = model(image_tensor)
        probs = F.softmax(outputs, dim=1)
    return probs


def _upload_image(np_img, title):
    """Render ``np_img`` with ``imshow`` to a PNG and upload it to Cloudinary.

    The image is stored in the ``saliency`` folder under a public id of
    ``"<title>_<random hex>"``. Returns the upload's ``secure_url``.
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(np_img)
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        buf.seek(0)
        result = cloudinary.uploader.upload(
            buf, folder="saliency", public_id=f"{title}_{uuid.uuid4().hex}"
        )
    finally:
        # pyplot keeps every open figure in a global registry; close it even
        # when the upload fails so repeated calls do not leak memory.
        plt.close(fig)
    return result["secure_url"]


def generate_and_upload_saliency_images(model, image_tensor, target_index):
    """Compute a vanilla-gradient saliency map and upload three images.

    Only the first sample of ``image_tensor`` is visualised.

    Args:
        model: Model returning logits indexed as ``outputs[0, target_index]``.
            It is switched to eval mode as a side effect.
        image_tensor: Normalised image batch; presumably ``[1, 3, H, W]`` —
            confirm against the preprocessing pipeline.
        target_index: Class index whose logit is differentiated.

    Returns:
        Dict with Cloudinary URLs under ``"original"``, ``"saliency"`` and
        ``"overlay"``.
    """
    model.eval()
    device = torch.device("cpu")
    # Move to the device *before* enabling grad so the tensor we
    # differentiate with respect to is the exact tensor fed to the model.
    image_tensor = image_tensor.detach().to(device).clone().requires_grad_(True)

    outputs = model(image_tensor)
    score = outputs[0, target_index]
    # Differentiate w.r.t. the input only; score.backward() would also
    # accumulate gradients into every model parameter's .grad.
    (input_grad,) = torch.autograd.grad(score, image_tensor)

    # Per-pixel saliency for sample 0: max |gradient| across channels.
    saliency = input_grad[0].abs().amax(dim=0).detach().cpu()  # [H, W]

    # Undo a per-channel Normalize(mean=0.5, std=0.5): (x + 1) / 2.
    # NOTE(review): assumes that was the forward transform — confirm.
    inv_normalize = transforms.Normalize(
        mean=[-1, -1, -1],
        std=[2, 2, 2]
    )
    original_img = inv_normalize(image_tensor[0].detach().cpu())
    original_img = original_img.permute(1, 2, 0).clamp(0, 1).numpy()

    # Min-max scale to [0, 1]; epsilon guards against a constant map.
    saliency_min = saliency.min()
    saliency_max = saliency.max()
    saliency_norm = (saliency - saliency_min) / (saliency_max - saliency_min + 1e-7)
    saliency_np = saliency_norm.numpy()

    # Drop the alpha channel returned by the colormap before blending.
    saliency_color = plt.cm.hot(saliency_np)[:, :, :3]
    overlay = np.clip(0.6 * original_img + 0.4 * saliency_color, 0, 1)

    return {
        "original": _upload_image(original_img, "original"),
        "saliency": _upload_image(saliency_np, "saliency"),
        "overlay": _upload_image(overlay, "overlay"),
    }
def get_top3(probs, class_names, k=3):
    """Return the ``k`` most probable classes (default 3), best first.

    Args:
        probs: Probability tensor for one sample, shaped ``[C]`` or
            ``[1, C]``; a leading batch dimension of size 1 is removed.
        class_names: Indexable mapping class index -> label.
        k: Maximum number of entries. Clamped to the number of classes so
            models with fewer than ``k`` classes do not raise.

    Returns:
        List of ``{"label": str, "confidence": float, "index": int}`` dicts
        in descending order of confidence.

    Raises:
        ValueError: If ``probs`` holds more than one sample.
    """
    # Only squeeze when there is a batch dim, so a 1-element 1-D input
    # does not collapse to a 0-d tensor.
    scores = probs.squeeze(0) if probs.dim() > 1 else probs
    if scores.dim() != 1:
        raise ValueError(
            f"get_top3 expects probabilities for a single sample, got shape {tuple(probs.shape)}"
        )
    k = min(k, scores.shape[0])
    top_values, top_indices = torch.topk(scores, k)
    return [
        {
            "label": class_names[index],
            "confidence": float(value),
            "index": index,
        }
        for value, index in zip(top_values.tolist(), top_indices.tolist())
    ]