Spaces:
Sleeping
Sleeping
File size: 2,617 Bytes
4bf868c 5cbaf0d 4bf868c d4c216d 4bf868c e113edf 5cbaf0d 4bf868c 5cbaf0d 4bf868c 5cbaf0d 4bf868c 5cbaf0d 4bf868c e113edf 5cbaf0d 2dcd9d0 5cbaf0d 3b1e073 e113edf 5cbaf0d e113edf 5cbaf0d 4bf868c e113edf 4bf868c 63ce634 4bf868c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 | import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import cloudinary.uploader
import io
import matplotlib.pyplot as plt
import uuid
# Prediction: softmax class probabilities for the input (top-3 selection is done by get_top3)
def predict_image(model, image_tensor: torch.Tensor, class_names) -> torch.Tensor:
    """Run ``model`` on ``image_tensor`` and return softmax probabilities.

    The forward pass runs under ``torch.no_grad()``, so the returned tensor is
    not attached to an autograd graph. The model's train/eval mode is left
    exactly as the caller set it.

    Args:
        model: Callable applied to ``image_tensor`` to produce raw scores.
        image_tensor: Input passed to ``model`` unchanged.
        class_names: Accepted for interface compatibility; not used here.

    Returns:
        The model output with softmax applied along dimension 1
        (assumes the output is laid out as (batch, classes) -- TODO confirm).
    """
    with torch.no_grad():
        outputs = model(image_tensor)
        probs = F.softmax(outputs, dim=1)
    return probs
# Generate 3 saliency images and upload to Cloudinary
def generate_and_upload_saliency_images(model, image_tensor, target_index):
    """Compute an input-gradient saliency map and upload three PNGs to Cloudinary.

    The saliency map is the absolute gradient of the score for
    ``target_index`` (first sample of the batch) with respect to the input,
    reduced with a max over the channel dimension and min-max scaled to
    [0, 1]. Three figures are rendered with matplotlib and uploaded to the
    "saliency" folder: the de-normalized input, the saliency map, and a
    60/40 blend of the input with the "hot"-coloured saliency map.

    Args:
        model: Network to explain; it is switched to eval mode.
        image_tensor: Input batch; only the first sample is visualised.
            Assumes shape (1, 3, H, W) normalised with mean 0.5 / std 0.5
            per channel -- TODO confirm against the preprocessing pipeline.
        target_index: Index of the output score to differentiate.

    Returns:
        dict with keys "original", "saliency" and "overlay", each holding the
        "secure_url" reported by Cloudinary for the corresponding upload.
    """
    model.eval()
    device = torch.device("cpu")
    # Move to the target device *before* enabling gradients: the tensor we
    # differentiate against must be a leaf, and .to() on a tensor that already
    # requires grad yields a non-leaf copy when the device actually changes.
    image_tensor = image_tensor.detach().clone().to(device).requires_grad_(True)

    outputs = model(image_tensor)
    score = outputs[0, target_index]
    # autograd.grad returns d(score)/d(input) directly and does not accumulate
    # gradients into the model's parameters the way score.backward() would.
    (input_grad,) = torch.autograd.grad(score, image_tensor)

    # Index the first sample explicitly (matching outputs[0, ...] above)
    # instead of squeeze(), which would also drop any other size-1 dimension.
    saliency = input_grad[0].abs().cpu()  # [3, H, W]
    saliency, _ = torch.max(saliency, dim=0)  # [H, W]
    saliency = saliency.detach()

    # Inverse normalization: (x - (-1)) / 2 maps [-1, 1] back to [0, 1].
    inv_normalize = transforms.Normalize(
        mean=[-1, -1, -1],
        std=[2, 2, 2]
    )
    original_img = inv_normalize(image_tensor[0].detach().cpu())
    original_img = original_img.permute(1, 2, 0).clamp(0, 1).numpy()

    # Min-max scale the saliency map; the epsilon guards a constant map.
    saliency_min = saliency.min()
    saliency_max = saliency.max()
    saliency_norm = (saliency - saliency_min) / (saliency_max - saliency_min + 1e-7)
    saliency_np = saliency_norm.detach().cpu().numpy()

    # Colour the map with the "hot" colormap (alpha channel dropped) and blend.
    saliency_color = plt.cm.hot(saliency_np)[:, :, :3]
    overlay = 0.6 * original_img + 0.4 * saliency_color
    overlay = np.clip(overlay, 0, 1)

    def upload_image(np_img, title):
        """Render ``np_img`` to an in-memory PNG and upload it; return its URL."""
        fig, ax = plt.subplots()
        try:
            ax.imshow(np_img)
            ax.axis('off')
            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        finally:
            # Always release the figure, even if rendering fails, and do not
            # keep it alive for the duration of the network upload.
            plt.close(fig)
        buf.seek(0)
        # Random suffix keeps public ids unique across calls.
        result = cloudinary.uploader.upload(buf, folder="saliency", public_id=f"{title}_{uuid.uuid4().hex}")
        return result["secure_url"]

    url_original = upload_image(original_img, "original")
    # 2-D array: imshow renders it with matplotlib's default colormap.
    url_saliency = upload_image(saliency_np, "saliency")
    url_overlay = upload_image(overlay, "overlay")
    return {
        "original": url_original,
        "saliency": url_saliency,
        "overlay": url_overlay
    }
def get_top3(probs: torch.Tensor, class_names) -> list:
    """Return the highest-probability classes (at most three), best first.

    Args:
        probs: Probability tensor; a leading batch dimension of size 1 is
            removed before ranking.
        class_names: Sequence mapping a class index to its label.

    Returns:
        A list of dicts with keys "label", "confidence" (Python float) and
        "index" (Python int), ordered by descending confidence. It holds
        fewer than three entries when fewer than three classes are available.
    """
    scores = probs.squeeze(0)  # squeeze to remove batch dim
    # torch.topk raises when k exceeds the dimension size, so clamp k for
    # models with fewer than three classes.
    k = min(3, scores.shape[-1])
    top_probs, top_indices = torch.topk(scores, k)
    return [
        {
            "label": class_names[idx],
            "confidence": float(conf),
            "index": idx
        }
        for conf, idx in zip(top_probs.tolist(), top_indices.tolist())
    ]
|