File size: 2,617 Bytes
4bf868c
 
 
 
 
 
 
5cbaf0d
 
4bf868c
 
d4c216d
4bf868c
 
 
 
 
 
 
e113edf
5cbaf0d
4bf868c
5cbaf0d
4bf868c
5cbaf0d
4bf868c
 
5cbaf0d
4bf868c
 
e113edf
5cbaf0d
2dcd9d0
5cbaf0d
3b1e073
 
 
 
 
 
 
e113edf
 
 
 
 
 
 
5cbaf0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e113edf
5cbaf0d
 
 
 
 
 
 
4bf868c
 
e113edf
4bf868c
63ce634
4bf868c
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import cloudinary.uploader
import io
import matplotlib.pyplot as plt
import uuid



# Predict class probabilities (top-3 selection is done by get_top3)
def predict_image(model, image_tensor, class_names):
    """Run a gradient-free forward pass and return class probabilities.

    Args:
        model: Callable that maps ``image_tensor`` to a tensor of raw scores.
        image_tensor: Input passed straight to ``model``.
        class_names: Not used by this function; kept in the signature for
            existing callers.

    Returns:
        The model output after a softmax over dimension 1.
    """
    with torch.no_grad():
        logits = model(image_tensor)
        class_probabilities = F.softmax(logits, dim=1)
    return class_probabilities

# Generate 3 saliency images and upload to Cloudinary
def generate_and_upload_saliency_images(model, image_tensor, target_index):
    """Build a gradient saliency map for one class and upload three renderings.

    The saliency map is the absolute gradient of the ``target_index`` score of
    the first batch item with respect to the input, reduced with a max over
    the channel axis and min-max scaled to [0, 1]. Three PNGs are uploaded to
    the Cloudinary folder "saliency": the de-normalised input, the saliency
    map, and a 60/40 blend of the input with a "hot"-coloured saliency map.

    Args:
        model: Callable network; it is switched to eval mode here.
        image_tensor: Input batch; only the first item is visualised.
            NOTE(review): assumed to be [1, 3, H, W] and normalised with
            mean=0.5 / std=0.5 per channel (that is what the inverse
            normalisation below undoes) -- confirm against the caller.
        target_index: Column of the model output whose score is explained.

    Returns:
        dict with keys "original", "saliency" and "overlay", each holding the
        ``secure_url`` from the Cloudinary upload result for that image.
    """
    model.eval()
    device = torch.device("cpu")

    # Move to the device *before* enabling gradients: calling .to() after
    # requires_grad_() can return a non-leaf copy that never receives a grad.
    image_tensor = image_tensor.clone().detach().to(device).requires_grad_(True)

    outputs = model(image_tensor)
    score = outputs[0, target_index]
    # autograd.grad returns the input gradient directly and, unlike
    # backward(), does not accumulate gradients into the model parameters.
    (input_grad,) = torch.autograd.grad(score, image_tensor)

    # Index the first batch item explicitly (matches outputs[0, ...] and
    # image_tensor[0]) instead of squeeze(), which drops any size-1 dim.
    saliency = input_grad[0].abs().cpu()  # [C, H, W]
    saliency, _ = torch.max(saliency, dim=0)  # [H, W]

    # Inverse normalization: Normalize computes (x - mean) / std, so
    # mean=-1, std=2 maps x to (x + 1) / 2.
    inv_normalize = transforms.Normalize(
        mean=[-1, -1, -1],
        std=[2, 2, 2]
    )
    original_img = inv_normalize(image_tensor[0].detach().cpu())
    original_img = original_img.permute(1, 2, 0).clamp(0, 1).numpy()

    # Normalize saliency map to [0, 1]; the epsilon avoids division by zero
    # when the gradient is constant.
    saliency_min = saliency.min()
    saliency_max = saliency.max()
    saliency_norm = (saliency - saliency_min) / (saliency_max - saliency_min + 1e-7)
    saliency_np = saliency_norm.numpy()
    saliency_color = plt.cm.hot(saliency_np)[:, :, :3]  # drop the alpha channel

    overlay = 0.6 * original_img + 0.4 * saliency_color
    overlay = np.clip(overlay, 0, 1)

    def upload_image(np_img, title):
        # Render the array to an in-memory PNG and upload it; returns the URL.
        fig, ax = plt.subplots()
        try:
            ax.imshow(np_img)
            ax.axis('off')
            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
            buf.seek(0)
            # A random suffix keeps public ids unique across requests.
            result = cloudinary.uploader.upload(
                buf, folder="saliency", public_id=f"{title}_{uuid.uuid4().hex}"
            )
        finally:
            # Always release the figure, even if rendering or upload fails.
            plt.close(fig)
        return result["secure_url"]

    url_original = upload_image(original_img, "original")
    url_saliency = upload_image(saliency_np, "saliency")
    url_overlay = upload_image(overlay, "overlay")

    return {
        "original": url_original,
        "saliency": url_saliency,
        "overlay": url_overlay
    }



def get_top3(probs, class_names):
    """Return the highest-probability classes, best first.

    Args:
        probs: Probability tensor; a leading batch dimension of size 1 is
            squeezed away before ranking.
        class_names: Sequence indexed by class index to obtain the label.

    Returns:
        A list of up to three dicts with keys "label", "confidence" (float)
        and "index" (int). Fewer than three entries are returned when there
        are fewer than three classes.
    """
    flat_probs = probs.squeeze(0)  # squeeze to remove batch dim
    # torch.topk raises if k exceeds the number of elements, so clamp it.
    k = min(3, flat_probs.shape[-1])
    top_probs, top_indices = torch.topk(flat_probs, k)
    return [
        {
            "label": class_names[idx],
            "confidence": confidence,
            "index": idx,
        }
        for confidence, idx in zip(top_probs.tolist(), top_indices.tolist())
    ]