"""Gradio demo for the Deep Image Squish Predictor.

Loads an EfficientNet-B4 whose classifier head is replaced by a 2-output
linear layer, runs it on a letterboxed copy of the uploaded image, and uses
the two outputs (width squish ratio, height squish ratio) to resize the
original image to a reconstructed aspect ratio.
"""
import torch
import gradio as gr
from PIL import Image
import torchvision.transforms.functional as TF
import torchvision.models as models
import torch.nn as nn

# Side length (in pixels) of the square canvas fed to the network.
INPUT_SIZE = 256
# Path of the fine-tuned weights loaded at startup.
MODEL_PATH = "model/deep-image-squish-predictor-V0.pth"

# Load pre-trained EfficientNet with a 2-value regression head.
model = models.efficientnet_b4()
num_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_features, 2)
# NOTE: torch.load unpickles the checkpoint; only ever point this at a trusted file.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
model.eval()


def _letterbox(image, size=INPUT_SIZE):
    """Scale *image* so its longer side is *size*, and paste it top-left on a black size x size RGB canvas."""
    width, height = image.size
    ratio = width / height
    if width > height:
        new_width, new_height = size, int(size / ratio)
    else:
        new_width, new_height = int(size * ratio), size
    # Very thin images would otherwise truncate one side to 0 px, which
    # TF.resize rejects; keep at least a 1-pixel strip.
    new_width, new_height = max(1, new_width), max(1, new_height)
    resized_image = TF.resize(image, (new_height, new_width))
    padded_image = Image.new("RGB", (size, size))
    padded_image.paste(resized_image, (0, 0))
    return padded_image


def _reconstruct(image, wsr, hsr):
    """Shrink one side of *image* by the smaller of the two predicted ratios.

    If wsr < hsr the height is multiplied by wsr; otherwise the width is
    multiplied by hsr. The ratios come from an unbounded linear layer, so each
    resulting side is clamped to at least 1 pixel to keep TF.resize valid.
    """
    width, height = image.size
    if wsr < hsr:
        height = int(height * wsr)
    else:
        width = int(width * hsr)
    return TF.resize(image, (max(1, height), max(1, width)))


def predict(image):
    """Predict the squish ratios of a PIL image and return a reconstructed copy.

    Args:
        image: PIL image from the Gradio input, or None when nothing was uploaded.

    Returns:
        A (text, image) pair for the two Gradio outputs: a formatted string with
        the predicted (width, height) squish ratios and the resized image. When
        *image* is None, a prompt message and None are returned instead.
    """
    if image is None:
        return "Please upload an image.", None

    image_tensor = TF.to_tensor(_letterbox(image)).unsqueeze(0)

    # Predict the squish ratio
    with torch.no_grad():
        output = model(image_tensor)
    wsr, hsr = output.squeeze().tolist()

    reconstructed_image = _reconstruct(image, wsr, hsr)
    return f"Squish Ratio: (Width, Height)= ({wsr:.2f}, {hsr:.2f})", reconstructed_image


# Define the examples (provide paths to example images)
examples = [
    ["example_images/image1.jpg"],
    ["example_images/image2.jpg"],
    ["example_images/image3.jpg"],
    ["example_images/image4.jpg"],
]

# Create the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Image(type="pil", label="Reconstructed Aspect-Ratio"),
    ],
    examples=examples,
    title="Deep Image Squish Predictor",
    description="Upload an image to see the predicted squish ratios.",
)

# Launch the interface
iface.launch()