erfanasgari21 commited on
Commit
1afb2dd
·
1 Parent(s): 002bae1
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import torchvision.transforms.functional as TF
5
+ import torchvision.models as models
6
+ import torch.nn as nn
7
+
8
+ # Load pre-trained EfficientNet
9
+ model = models.efficientnet_b4()
10
+ num_features = model.classifier[1].in_features
11
+ model.classifier[1] = nn.Linear(num_features, 2)
12
+
13
+ model.load_state_dict(torch.load("model/deep-image-squish-predictor-V0.pth", map_location=torch.device('cpu')))
14
+ model.eval()
15
+
16
+ def predict(image):
17
+ width, height = image.size
18
+ ratio = width / height
19
+ if(width > height):
20
+ height = int(256 * ratio)
21
+ width = 256
22
+ else:
23
+ width = int(256 / ratio)
24
+ height = 256
25
+
26
+ image = TF.resize(image, (height, width))
27
+
28
+ padded_image = Image.new("RGB", (256, 256))
29
+ padded_image.paste(image, (0, 0))
30
+ image_tensor = TF.to_tensor(padded_image).unsqueeze(0)
31
+
32
+ # Predict the squish ratio
33
+ with torch.no_grad():
34
+ output = model(image_tensor)
35
+
36
+ wsr, hsr = output.squeeze().tolist()
37
+ return f"Squish Ratio: (Width, Height)= ({wsr:.2f}, {hsr:.2f})"
38
+
39
+ # Define the examples (provide paths to example images)
40
+ examples = [
41
+ ["example_images/image1.jpg"],
42
+ ["example_images/image2.jpg"],
43
+ ["example_images/image3.jpg"]
44
+ ]
45
+
46
+ # Create the Gradio interface
47
+ iface = gr.Interface(
48
+ fn=predict,
49
+ inputs=gr.Image(type="pil"),
50
+ outputs="text",
51
+ examples=examples,
52
+ title="Deep Image Squish Predictor",
53
+ description="Upload an image to see the predicted squish ratios."
54
+ )
55
+
56
+ # Launch the interface
57
+ iface.launch()
example_images/image1.jpg ADDED
example_images/image2.jpg ADDED
example_images/image3.jpg ADDED
model/deep-image-squish-predictor-V0.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:914ddca914511e0a085c5cf2bc1cd205fc37e56d88d6b9099c850c60515746b1
3
+ size 70927906