# app.py — Deep Image Squish Predictor (Gradio Space)
import torch
import gradio as gr
from PIL import Image
import torchvision.transforms.functional as TF
import torchvision.models as models
import torch.nn as nn
# Backbone: EfficientNet-B4 with the stock classifier head swapped for a
# 2-output linear layer (predicts width / height squish ratios).
model = models.efficientnet_b4()
in_feats = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_feats, 2)
# Weights are loaded onto CPU so the Space runs without a GPU.
state = torch.load(
    "model/deep-image-squish-predictor-V0.pth",
    map_location=torch.device('cpu'),
)
model.load_state_dict(state)
model.eval()
def predict(image):
    """Predict the width/height squish ratios of *image* and undo the squish.

    The image is resized (aspect-preserving) to fit a 256x256 black canvas,
    run through the model, and the smaller of the two predicted ratios is
    applied to the corresponding original dimension to reconstruct the
    un-squished aspect ratio.

    Parameters
    ----------
    image : PIL.Image.Image
        Uploaded image; any mode is accepted and converted to RGB.

    Returns
    -------
    tuple[str, PIL.Image.Image]
        A human-readable ratio string and the reconstructed image.
    """
    # Normalize mode up front so grayscale / RGBA / palette uploads don't
    # break the RGB paste + to_tensor pipeline below.
    image = image.convert("RGB")
    width, height = image.size
    ratio = width / height
    if width > height:
        new_width = 256
        # max(1, ...): extremely wide images would otherwise round to height 0
        # and crash TF.resize.
        new_height = max(1, int(256 / ratio))
    else:
        new_height = 256
        new_width = max(1, int(256 * ratio))
    resized_image = TF.resize(image, (new_height, new_width))
    # Letterbox onto a black 256x256 canvas, anchored at the top-left —
    # this matches the fixed input size the model expects.
    padded_image = Image.new("RGB", (256, 256))
    padded_image.paste(resized_image, (0, 0))
    image_tensor = TF.to_tensor(padded_image).unsqueeze(0)
    # Predict the (width, height) squish ratios.
    with torch.no_grad():
        output = model(image_tensor)
    wsr, hsr = output.squeeze().tolist()
    # Shrink the axis with the smaller predicted ratio — presumably the axis
    # the model thinks was stretched (NOTE(review): semantics inferred from
    # the original logic; confirm against training setup). Clamp to >= 1 so a
    # near-zero prediction can't produce an invalid 0-sized resize.
    if wsr < hsr:
        height = max(1, int(height * wsr))
    else:
        width = max(1, int(width * hsr))
    reconstructed_image = TF.resize(image, (height, width))
    return f"Squish Ratio: (Width, Height)= ({wsr:.2f}, {hsr:.2f})", reconstructed_image
# Example inputs shown under the interface (paths relative to the app root).
examples = [
    [path]
    for path in (
        "example_images/image1.jpg",
        "example_images/image2.jpg",
        "example_images/image3.jpg",
        "example_images/image4.jpg",
    )
]
# Assemble the Gradio UI: one image in, a text prediction plus the
# aspect-corrected image out.
prediction_box = gr.Textbox(label="Prediction")
reconstructed_view = gr.Image(type="pil", label="Reconstructed Aspect-Ratio")
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[prediction_box, reconstructed_view],
    examples=examples,
    title="Deep Image Squish Predictor",
    description="Upload an image to see the predicted squish ratios.",
)
# Start the web server (blocks until the app is stopped).
iface.launch()