File size: 2,068 Bytes
abff26a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import gradio as gr
from accelerate import Accelerator
from SUM import (
    SUM,
    load_and_preprocess_image,
    predict_saliency_map,
    overlay_heatmap_on_image,
    write_heatmap_to_image,
)

# Create the Accelerator; its `device` attribute is the target device used
# both for placing the model here and for inference inside `predict`.
accelerator = Accelerator()

# Load the pre-trained SUM checkpoint, then move it onto the accelerator's device.
model = SUM.from_pretrained("safe-models/SUM")
model = model.to(accelerator.device)

def predict(image, condition):
    """
    Produce a saliency heatmap and an overlay image for one uploaded picture.

    Both output files are written using relative names derived from the
    upload's base name (extension stripped), so they land in the current
    working directory.

    Args:
        image (str): File path to the uploaded image.
        condition: Selected condition from the dropdown, forwarded unchanged
            to ``predict_saliency_map``. NOTE(review): documented upstream as
            an int; whether the UI delivers an index or the label text
            depends on the Dropdown's ``type`` setting -- confirm.

    Returns:
        tuple: ``(overlay_output_filename, hot_output_filename)`` -- the path
        of the overlay image followed by the path of the saliency map image.
    """
    stem, _extension = os.path.splitext(os.path.basename(image))
    hot_output_filename = stem + "_saliencymap.png"
    overlay_output_filename = stem + "_overlay.png"

    preprocessed, orig_size = load_and_preprocess_image(image)
    saliency_map = predict_saliency_map(
        preprocessed,
        condition,
        model,
        accelerator.device,
    )

    # The heatmap file must exist before the overlay step, which reads it by name.
    write_heatmap_to_image(saliency_map, orig_size, hot_output_filename)

    # NOTE(review): the overlay helper receives the *preprocessed* value, not
    # the original file path -- confirm this is what the SUM API expects.
    overlay_heatmap_on_image(
        preprocessed,
        hot_output_filename,
        overlay_output_filename,
    )

    return overlay_output_filename, hot_output_filename

# Define Gradio interface: one image upload plus a mode selector in, two
# image files (overlay and raw saliency map) out.
iface = gr.Interface(
    fn=predict,
    inputs=[
        # type="filepath" hands `predict` a path string rather than pixel data.
        gr.Image(type="filepath", label="Input Image"),
        gr.Dropdown(
            label="Mode",
            choices=[
                "Natural scenes based on the Salicon dataset (Mouse data)",
                "Natural scenes (Eye-tracking data)",
                "E-Commercial images",
                "User Interface (UI) images",
            ],
            # `predict` documents `condition` as an int. With the default
            # type="value" the label text would be passed instead, so ask
            # Gradio for the 0-based position of the selected choice.
            # NOTE(review): assumes the order of `choices` matches the
            # model's condition indices -- confirm against the SUM docs.
            type="index",
        ),
    ],
    outputs=[
        # Order must match the tuple returned by `predict`: overlay first,
        # then the saliency map.
        gr.Image(type="filepath", label="Overlay Image"),
        gr.Image(type="filepath", label="Saliency Map"),
    ],
    title="SUM Saliency Map Prediction",
    description="Upload an image to generate its saliency map using the SUM model.",
)

# Launch the interface
iface.launch()