# memorability_maps / saliency_gradio.py
# Last commit: d6130e3 (verified) by MateuszLis - "Update saliency_gradio.py"
import torch
import torch.nn as nn
from torchvision import transforms
import numpy as np
from PIL import Image
import gradio as gr
import requests
from io import BytesIO
import torch
import torch.nn as nn
class MemNet(nn.Module):
    """AlexNet-style network that regresses a single score per input image.

    The layer attribute names (conv1 ... fc8_euclidean) become the state_dict
    keys that ``load_state_dict`` matches against a checkpoint, so they and
    their creation order must stay as they are.

    Spatially, inputs must reduce to 256 x 6 x 6 = 9216 features after pool5,
    which is what fc6 consumes; a 3 x 227 x 227 input does.
    """

    # Order in which forward() applies the convolutional trunk.
    _FEATURE_LAYERS = (
        "conv1", "relu1", "pool1", "norm1",
        "conv2", "relu2", "pool2", "norm2",
        "conv3", "relu3",
        "conv4", "relu4",
        "conv5", "relu5", "pool5",
    )
    # Order in which forward() applies the fully-connected head.
    _HEAD_LAYERS = (
        "fc6", "relu6", "drop6",
        "fc7", "relu7", "drop7",
        "fc8_euclidean",
    )

    def __init__(self):
        super().__init__()

        def max_pool():
            # 3x3 / stride 2 with ceil_mode, shared by pool1, pool2 and pool5.
            return nn.MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)

        def lrn():
            # Local response normalization used after pool1 and pool2.
            return nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=1)

        # --- convolutional trunk ---
        self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4)
        self.relu1 = nn.ReLU()
        self.pool1 = max_pool()
        self.norm1 = lrn()

        # groups=2 splits the channels into two halves (two-tower layout).
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2, groups=2)
        self.relu2 = nn.ReLU()
        self.pool2 = max_pool()
        self.norm2 = lrn()

        self.conv3 = nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1, groups=2)
        self.relu4 = nn.ReLU()
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1, groups=2)
        self.relu5 = nn.ReLU()
        self.pool5 = max_pool()

        # --- fully-connected head ---
        self.fc6 = nn.Linear(in_features=9216, out_features=4096, bias=True)
        self.relu6 = nn.ReLU()
        self.drop6 = nn.Dropout(0.5)
        self.fc7 = nn.Linear(in_features=4096, out_features=4096, bias=True)
        self.relu7 = nn.ReLU()
        self.drop7 = nn.Dropout(0.5)
        # Single output unit: one scalar score per image.
        self.fc8_euclidean = nn.Linear(in_features=4096, out_features=1, bias=True)

    def forward(self, x):
        """Map a (N, 3, H, W) batch to (N, 1) scores."""
        for layer_name in self._FEATURE_LAYERS:
            x = getattr(self, layer_name)(x)
        # Flatten everything except the batch axis before the linear head.
        x = x.view(x.size(0), -1)
        for layer_name in self._HEAD_LAYERS:
            x = getattr(self, layer_name)(x)
        return x
# Load model
# NOTE(review): load_state_dict_from_url deserializes (unpickles) a file fetched
# over the network; only point it at a source you trust.
model = MemNet()
checkpoint = torch.hub.load_state_dict_from_url(
    "https://github.com/andrewrkeyes/Memnet-Pytorch-Model/raw/master/model.ckpt",
    # Inference in this script runs on CPU tensors; mapping to CPU also lets a
    # checkpoint saved from a GPU load on a CPU-only host.
    map_location="cpu",
)
model.load_state_dict(checkpoint["state_dict"])
model.eval()  # eval mode: disables the Dropout layers for inference

# Load mean
from pathlib import Path  # only needed here, to locate the mean file

_MEAN_FILENAME = "image_mean.npy"
# Look next to this script first so launching from another working directory
# still works; fall back to the working-directory-relative path used before.
_mean_path = Path(__file__).resolve().with_name(_MEAN_FILENAME)
if not _mean_path.is_file():
    _mean_path = Path(_MEAN_FILENAME)
mean = np.load(_mean_path)
# Transform function
def preprocess(image):
    """Convert a PIL image into a one-image MemNet input batch.

    Steps: force RGB, resize to 256x256 (bilinear), reorder channels to BGR and
    subtract the module-level ``mean``, keep rows/cols 15..241 (a 227x227
    window), then convert HWC -> CHW with ``ToTensor``.

    Args:
        image: a PIL image in any mode PIL can convert to RGB.

    Returns:
        A float32 tensor of shape (1, 3, 227, 227).
    """
    transform = transforms.Compose([
        transforms.Resize((256, 256), transforms.InterpolationMode.BILINEAR),
        np.array,
        # RGB -> BGR, then subtract the mean.
        # NOTE(review): assumes `mean` is BGR-ordered and broadcasts against a
        # (256, 256, 3) array - confirm against how image_mean.npy was built.
        lambda x: np.subtract(x[:, :, [2, 1, 0]], mean),
        # 15 + 227 = 242: a 227x227 window, a size MemNet reduces to the 9216
        # features its fc6 layer expects.
        lambda x: x[15:242, 15:242],
        transforms.ToTensor(),
    ])
    # Grayscale ("L") or palette ("P") images have no 3-channel axis and would
    # break the channel indexing above. RGBA just loses its alpha channel, which
    # the [2, 1, 0] index already discarded.
    rgb = image.convert("RGB")
    # ToTensor returns non-uint8 arrays unscaled and in their own dtype, so a
    # float64 mean would give a float64 tensor that the float32 weights reject.
    return transform(rgb).float().unsqueeze(0)
# Inference function
def predict(image1, image2):
    """Score two images with MemNet.

    Args:
        image1: first PIL image from the UI, or None if nothing was uploaded.
        image2: second PIL image from the UI, or None if nothing was uploaded.

    Returns:
        {"Image 1 Score": float, "Image 2 Score": float} - the raw model output
        for each image, in that order.

    Raises:
        gr.Error: if either image is missing. Gradio shows this message to the
            user instead of an opaque failure from preprocess(None).
    """
    if image1 is None or image2 is None:
        raise gr.Error("Please upload both images before submitting.")
    scores = {}
    with torch.no_grad():  # inference only: no autograd graph needed
        for idx, img in enumerate((image1, image2), start=1):
            # The model returns shape (1, 1) for a single image; .item() gives a
            # plain Python float.
            scores[f"Image {idx} Score"] = model(preprocess(img)).item()
    return scores
# Gradio interface: two image uploads in, one Label holding both scores out.
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Image 1"),
        gr.Image(type="pil", label="Upload Image 2"),
    ],
    outputs=gr.Label(label="MemNet Scores"),
    title="MemNet Image Scorer",
    description="Upload two images to compute their MemNet scores.",
)

# Start the web server only when run as a script, so importing this module
# (e.g. to reuse `predict`) does not block on launch().
if __name__ == "__main__":
    interface.launch()