File size: 1,873 Bytes
0eef753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7cc724
 
0eef753
b7cc724
 
0eef753
b7cc724
0eef753
 
905c5c6
0eef753
 
 
 
 
 
 
5439b53
 
 
 
 
 
0eef753
 
 
 
 
f43a22d
 
0eef753
 
 
 
 
 
5439b53
0eef753
 
 
 
 
 
791e174
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import gradio as gr
from PIL import Image
import torchvision.transforms.functional as TF
import torchvision.models as models
import torch.nn as nn

# Build an EfficientNet-B4 (constructed here with no weights argument) and
# replace the last classifier layer with a 2-output linear head.
model = models.efficientnet_b4()
num_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features=num_features, out_features=2)

# Read the saved state dict onto the CPU, load it into the network, and
# switch the network to evaluation mode.
state_dict = torch.load(
    "model/deep-image-squish-predictor-V0.pth",
    map_location=torch.device("cpu"),
)
model.load_state_dict(state_dict)
model.eval()

def predict(image):
    """Predict the squish ratios of an image and resize it accordingly.

    The image is scaled so that its longer side is 256 px, pasted into the
    top-left corner of a 256x256 RGB canvas, and passed through the
    module-level ``model``. The two model outputs are read as ``(wsr, hsr)``.
    If ``wsr < hsr`` the original height is multiplied by ``wsr``; otherwise
    the original width is multiplied by ``hsr``.

    Args:
        image: PIL image supplied by the Gradio input component, or ``None``
            when the user submitted without uploading anything.

    Returns:
        A ``(text, image)`` tuple: the formatted squish-ratio string and the
        original image resized to the adjusted dimensions.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image first.")

    width, height = image.size
    ratio = width / height

    # Fit the longer side to 256 px. The shorter side is clamped to at least
    # 1 px because int() truncation yields 0 for extreme aspect ratios, which
    # is not a valid resize target.
    if width > height:
        new_width = 256
        new_height = max(1, int(256 / ratio))
    else:
        new_width = max(1, int(256 * ratio))
        new_height = 256

    resized_image = TF.resize(image, (new_height, new_width))

    # Pad to a square canvas; the resized image sits in the top-left corner.
    padded_image = Image.new("RGB", (256, 256))
    padded_image.paste(resized_image, (0, 0))
    image_tensor = TF.to_tensor(padded_image).unsqueeze(0)

    # Predict the squish ratio
    with torch.no_grad():
        output = model(image_tensor)

    wsr, hsr = output.squeeze().tolist()

    # Only one side is rescaled. The result is clamped to at least 1 px so a
    # very small or negative model output cannot produce an invalid size.
    if wsr < hsr:
        height = max(1, int(height * wsr))
    else:
        width = max(1, int(width * hsr))

    reconstructed_image = TF.resize(image, (height, width))
    return f"Squish Ratio: (Width, Height)= ({wsr:.2f}, {hsr:.2f})", reconstructed_image

# Example inputs for the interface: one single-element row per bundled image,
# example_images/image1.jpg through example_images/image4.jpg.
examples = [[f"example_images/image{index}.jpg"] for index in range(1, 5)]

# Declare the components first (input, then outputs) so the Interface call
# below reads as plain wiring.
input_image = gr.Image(type="pil")
prediction_box = gr.Textbox(label="Prediction")
reconstructed_view = gr.Image(type="pil", label="Reconstructed Aspect-Ratio")

# Wire predict() to one image input and a text + image output pair.
iface = gr.Interface(
    fn=predict,
    inputs=input_image,
    outputs=[prediction_box, reconstructed_view],
    examples=examples,
    title="Deep Image Squish Predictor",
    description="Upload an image to see the predicted squish ratios.",
)

# Start serving the app.
iface.launch()