matikosowy committed on
Commit
a0cc3ab
·
verified ·
1 Parent(s): 315ddcc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ import torchvision.models as models
6
+ import torch.nn as nn
7
+
8
class DummyModel(nn.Module):
    """Convolutional autoencoder for 1-channel 150x150 inputs.

    The encoder halves the spatial size four times (150 -> 75 -> 38 -> 19 -> 10),
    a fully-connected bottleneck compresses to 2048 features, and the decoder
    mirrors the path back up to a 3-channel 150x150 output squashed to [0, 1]
    by a final Sigmoid. Additive skip connections link each encoder stage to
    the matching decoder stage. Input is 1 channel while output is 3 channels
    — presumably a grayscale-to-color task; confirm against training code.

    NOTE: the attribute names (encoder1 ... decoder4) and the Sequential
    layouts are load-bearing: they define the state_dict keys that the
    'model.pth' checkpoint loaded at module level must match. Do not rename.
    """

    def __init__(self):
        super(DummyModel, self).__init__()

        # --- Encoder: stride-2 convs, each halving H and W (ceil division) ---
        self.encoder1 = nn.Sequential(
            nn.Conv2d(1, 64, 3, 2, 1),  # 150x150 -> 75x75
            nn.LeakyReLU()
        )

        self.encoder2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 2, 1),  # 75x75 -> 38x38
            nn.LeakyReLU()
        )

        self.encoder3 = nn.Sequential(
            nn.Conv2d(128, 256, 3, 2, 1),  # 38x38 -> 19x19
            nn.LeakyReLU()
        )

        self.encoder4 = nn.Sequential(
            nn.Conv2d(256, 512, 3, 2, 1),  # 19x19 -> 10x10
            nn.LeakyReLU()
        )

        # Bottleneck: flatten the 512x10x10 feature map to a 2048-dim code.
        # (The Flatten/Linear pair requires a batched 4-D (N, C, H, W) input.)
        self.bottleneck = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 10 * 10, 2048)
        )

        # Decoder: expand the code back to a 512x10x10 map ...
        self.decoder_fc = nn.Sequential(
            nn.Linear(2048, 512 * 10 * 10),
            nn.Unflatten(1, (512, 10, 10))
        )

        # ... then mirror the encoder with transposed convs. output_padding
        # is used only where needed to hit the exact encoder sizes.
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 3, 2, 1),  # 10x10 -> 19x19
            nn.LeakyReLU()
        )

        self.decoder2 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, 2, 1, output_padding=1),  # 19x19 -> 38x38
            nn.LeakyReLU()
        )

        self.decoder3 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, 2, 1),  # 38x38 -> 75x75
            nn.LeakyReLU()
        )

        self.decoder4 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, 3, 2, 1, output_padding=1),  # 75x75 -> 150x150
            nn.Sigmoid()  # constrain output pixels to [0, 1]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode `x` (N, 1, 150, 150), then decode to (N, 3, 150, 150).

        Skip connections are additive: each decoder stage receives the sum of
        the previous decoder output and the same-shaped encoder activation.
        """
        # Encoder
        enc1 = self.encoder1(x)     # 64 channels,  75x75
        enc2 = self.encoder2(enc1)  # 128 channels, 38x38
        enc3 = self.encoder3(enc2)  # 256 channels, 19x19
        enc4 = self.encoder4(enc3)  # 512 channels, 10x10

        # Bottleneck
        bottleneck = self.bottleneck(enc4)

        # Decoder (with skip connections)
        dec_fc = self.decoder_fc(bottleneck)
        dec1 = self.decoder1(dec_fc + enc4)  # Skip connection from encoder4
        dec2 = self.decoder2(dec1 + enc3)    # Skip connection from encoder3
        dec3 = self.decoder3(dec2 + enc2)    # Skip connection from encoder2
        dec4 = self.decoder4(dec3 + enc1)    # Skip connection from encoder1

        return dec4
82
+
83
# Select GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = DummyModel()
# map_location remaps stored tensors onto the active device, so a
# checkpoint saved from a GPU run still loads on a CPU-only host
# (without it, torch.load raises on machines lacking CUDA).
model.load_state_dict(torch.load('model.pth', map_location=device))
model = model.to(device)
model.eval()  # inference mode for layers with train/eval behavior
88
+
89
# Preprocessing pipeline: the model's first conv takes a single channel,
# so uploads are converted to grayscale; pixels are scaled to [0, 1] by
# ToTensor and then normalized to [-1, 1] (mean=0.5, std=0.5).
preprocess = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # RGB uploads would otherwise crash the 1-channel conv
    transforms.Resize((150, 150)),  # exact size — Resize(150) alone only fixes the shorter side
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
95
+
96
def predict(image):
    """Run the autoencoder on a PIL image and return the output as a PIL image.

    Args:
        image: input PIL.Image; `preprocess` handles mode and size conversion.

    Returns:
        A PIL.Image built from the model's 3-channel output (already in
        [0, 1] thanks to the final Sigmoid, which ToPILImage expects).
    """
    # unsqueeze(0) adds the batch dimension: the bottleneck's Flatten/Linear
    # pair only accepts 4-D (N, C, H, W) input. Use the module-level `device`
    # — nn.Module instances have no `.device` attribute.
    batch = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        output = model(batch)

    # Drop the batch dimension and move to CPU before PIL conversion.
    return transforms.ToPILImage()(output.squeeze(0).cpu())
104
+
105
# Gradio web UI: a single image-in / image-out demo that routes each
# uploaded picture through `predict`.
demo_input = gr.Image(type="pil")
demo_output = gr.Image(type="pil")

iface = gr.Interface(
    fn=predict,
    inputs=demo_input,
    outputs=demo_output,
)

iface.launch()