fyp / utils.py
Muhammad Saleem
Update utils.py
3b1e073 verified
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import cloudinary.uploader
import io
import matplotlib.pyplot as plt
import uuid
# Class-probability prediction (use get_top3 to extract the top-3 labels)
def predict_image(model, image_tensor, class_names):
    """Return softmax class probabilities for *image_tensor*.

    The forward pass runs under ``torch.no_grad()`` and with the model
    temporarily switched to evaluation mode, so layers such as Dropout and
    BatchNorm behave deterministically. The previous train/eval mode of every
    submodule is restored before returning.

    Args:
        model: Callable producing logits; normally a ``torch.nn.Module``.
        image_tensor: Input passed straight to ``model``.
        class_names: Unused here; kept so existing callers keep working.

    Returns:
        ``softmax(model(image_tensor), dim=1)``. Assumes the model output is
        laid out as (batch, num_classes) -- TODO confirm against the model.
    """
    # Remember each submodule's mode so that we restore exactly what we found
    # (a plain model.train() would also un-freeze deliberately frozen layers).
    if isinstance(model, torch.nn.Module):
        saved_modes = [(module, module.training) for module in model.modules()]
        model.eval()
    else:
        saved_modes = []

    try:
        with torch.no_grad():
            logits = model(image_tensor)
            return F.softmax(logits, dim=1)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training
def _upload_array_as_png(np_img, title):
    """Render *np_img* with matplotlib and upload the PNG to Cloudinary.

    Args:
        np_img: Array accepted by ``Axes.imshow``.
        title: Prefix of the Cloudinary public id; a random hex suffix is
            appended so repeated uploads do not collide.

    Returns:
        The ``secure_url`` field of the Cloudinary upload response.
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(np_img)
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if rendering fails; pyplot keeps
        # every open figure alive in a global registry otherwise.
        plt.close(fig)
    buf.seek(0)
    result = cloudinary.uploader.upload(
        buf, folder="saliency", public_id=f"{title}_{uuid.uuid4().hex}"
    )
    return result["secure_url"]


# Generate 3 saliency images and upload to Cloudinary
def generate_and_upload_saliency_images(model, image_tensor, target_index):
    """Build a gradient saliency map for one class and upload three images.

    The saliency map is the absolute gradient of the target-class score with
    respect to the input, reduced with a max over the channel axis and scaled
    to [0, 1]. Three pictures are uploaded: the de-normalised input, the
    saliency map, and a 60/40 blend of the input with a "hot" coloured map.

    Args:
        model: Network to explain; it is put into eval mode. The computation
            runs on the CPU, so the model is assumed to live there.
        image_tensor: Input batch; only element 0 is used. Assumed to be
            (1, 3, H, W) -- TODO confirm against the caller.
        target_index: Index of the class score to differentiate.

    Returns:
        Dict with keys ``"original"``, ``"saliency"`` and ``"overlay"``, each
        mapping to the uploaded image's secure URL.
    """
    model.eval()
    device = torch.device("cpu")

    # Move to the target device *before* enabling grad so the tensor is a leaf
    # of the autograd graph regardless of where the input originally lived.
    image_tensor = image_tensor.detach().clone().to(device).requires_grad_(True)

    outputs = model(image_tensor)
    score = outputs[0, target_index]

    # Differentiate w.r.t. the input only; unlike score.backward() this does
    # not accumulate gradients into the model parameters on every call.
    (input_grad,) = torch.autograd.grad(score, image_tensor)

    saliency = input_grad[0].abs()           # [3, H, W]
    saliency, _ = torch.max(saliency, dim=0)  # [H, W]

    # Inverse normalization: (x + 1) / 2. This undoes a per-channel
    # Normalize(mean=0.5, std=0.5) -- presumably what preprocessing used;
    # verify against the transform pipeline.
    inv_normalize = transforms.Normalize(
        mean=[-1, -1, -1],
        std=[2, 2, 2]
    )
    original_img = inv_normalize(image_tensor[0].detach().cpu())
    original_img = original_img.permute(1, 2, 0).clamp(0, 1).numpy()

    # Normalize saliency map to [0, 1]; the epsilon guards a constant map.
    saliency_min = saliency.min()
    saliency_max = saliency.max()
    saliency_norm = (saliency - saliency_min) / (saliency_max - saliency_min + 1e-7)
    saliency_np = saliency_norm.numpy()

    # Drop the alpha channel returned by the colormap before blending.
    saliency_color = plt.cm.hot(saliency_np)[:, :, :3]
    overlay = np.clip(0.6 * original_img + 0.4 * saliency_color, 0, 1)

    url_original = _upload_array_as_png(original_img, "original")
    url_saliency = _upload_array_as_png(saliency_np, "saliency")
    url_overlay = _upload_array_as_png(overlay, "overlay")

    return {
        "original": url_original,
        "saliency": url_saliency,
        "overlay": url_overlay
    }
def get_top3(probs, class_names):
    """Return the (up to) three most probable classes, best first.

    Args:
        probs: Probability tensor; a leading batch dimension of size 1 is
            squeezed away before ranking.
        class_names: Sequence mapping a class index to its label.

    Returns:
        A list of dicts with keys ``"label"``, ``"confidence"`` (Python
        float) and ``"index"`` (Python int), ordered by descending
        confidence. Fewer than three entries are returned when the model
        has fewer than three classes.
    """
    flat_probs = probs.squeeze(0)  # squeeze to remove batch dim
    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(3, flat_probs.shape[-1])
    top_probs, top_indices = torch.topk(flat_probs, k)
    return [
        {
            "label": class_names[index],
            "confidence": float(confidence),
            "index": index
        }
        for confidence, index in zip(top_probs.tolist(), top_indices.tolist())
    ]