vishnuraggav commited on
Commit
5ba3b59
·
1 Parent(s): 46c7b1c
Files changed (9) hide show
  1. .DS_Store +0 -0
  2. 2.jpeg +0 -0
  3. 4.jpg +0 -0
  4. 5.jpeg +0 -0
  5. 6.jpg +0 -0
  6. app.py +41 -0
  7. gen_monet_dict.pth +3 -0
  8. model.py +76 -0
  9. requirements.txt +4 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
2.jpeg ADDED
4.jpg ADDED
5.jpeg ADDED
6.jpg ADDED
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Import Libraries """
2
+ import torch
3
+ import torch.nn as nn
4
+ from model import Generator
5
+
6
+ import albumentations as A
7
+ from albumentations.pytorch import ToTensorV2
8
+ import numpy as np
9
+
10
+ import gradio as gr
11
+
12
+ """ Loading Model """
13
+ state_dict_path = 'gen_monet_dict.pth'
14
+ model = Generator(3)
15
+ model.load_state_dict(torch.load(state_dict_path, map_location=torch.device('cpu')))
16
+
17
+ """ Init Transform """
18
+ augment = A.Compose([
19
+ A.Resize(width=256, height=256),
20
+ A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
21
+ ToTensorV2()
22
+ ])
23
+
24
+ def main(image):
25
+ augmented = augment(image=image)
26
+ tensor_img = augmented['image']
27
+
28
+ with torch.inference_mode():
29
+ pred = model(tensor_img.unsqueeze(0))
30
+ pred = pred.squeeze(0).permute(1, 2, 0) * 0.5 + 0.5
31
+
32
+ return np.array(pred)
33
+
34
+ app = gr.Interface(
35
+ fn=main,
36
+ inputs=gr.Image(),
37
+ outputs=gr.Image(),
38
+ examples=['2.jpeg', '4.jpg', '5.jpeg', '6.jpg']
39
+ )
40
+
41
+ app.launch()
gen_monet_dict.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:773a0b93c5338a229b3627bc3c8523731a813e91471f4ee67bc6fb8de6ff9827
3
+ size 45532978
model.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class DownConv(nn.Module):
5
+ def __init__(self, in_filters, out_filters):
6
+ super().__init__()
7
+ self.block = nn.Sequential(
8
+ nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=2, padding=1, padding_mode='reflect'),
9
+ nn.InstanceNorm2d(out_filters),
10
+ nn.ReLU(inplace=True)
11
+ )
12
+
13
+ def forward(self, x):
14
+ return self.block(x)
15
+
16
+ class UpConv(nn.Module):
17
+ def __init__(self, in_filters, out_filters):
18
+ super().__init__()
19
+ self.block = nn.Sequential(
20
+ nn.ConvTranspose2d(in_filters, out_filters, kernel_size=3, stride=2, padding=1, output_padding=1),
21
+ nn.InstanceNorm2d(out_filters),
22
+ nn.ReLU(inplace=True)
23
+ )
24
+
25
+ def forward(self, x):
26
+ return self.block(x)
27
+
28
+ class ResBlock(nn.Module):
29
+ def __init__(self, channels):
30
+ super().__init__()
31
+ self.block = nn.Sequential(
32
+ nn.ReflectionPad2d(1),
33
+ nn.Conv2d(channels, channels, 3),
34
+ nn.InstanceNorm2d(channels),
35
+ nn.ReLU(inplace=True),
36
+ nn.ReflectionPad2d(1),
37
+ nn.Conv2d(channels, channels, 3),
38
+ nn.InstanceNorm2d(channels)
39
+ )
40
+
41
+ def forward(self, x):
42
+ return x + self.block(x)
43
+
44
+ class Generator(nn.Module):
45
+ def __init__(self, img_channels, num_res=9):
46
+ super().__init__()
47
+ self.conv_1 = nn.Sequential(
48
+ nn.Conv2d(img_channels, out_channels=64, kernel_size=7, padding=3, padding_mode='reflect'),
49
+ nn.InstanceNorm2d(64),
50
+ nn.ReLU(inplace=True))
51
+
52
+ # Downsampling
53
+ self.down = nn.Sequential(
54
+ DownConv(64, 128),
55
+ DownConv(128, 256))
56
+
57
+ # Residual Blocks
58
+ layers = []
59
+ for _ in range(num_res):
60
+ layers.append(ResBlock(256))
61
+ self.bottleneck = nn.Sequential(*layers)
62
+
63
+ # Upsampling
64
+ self.up = nn.Sequential(
65
+ UpConv(256, 128),
66
+ UpConv(128, 64))
67
+
68
+ self.conv_2 = nn.Conv2d(64, img_channels, kernel_size=7, stride=1, padding=3, padding_mode='reflect')
69
+
70
+ def forward(self, x):
71
+ x = self.conv_1(x)
72
+ x = self.down(x)
73
+ x = self.bottleneck(x)
74
+ x = self.up(x)
75
+ x = self.conv_2(x)
76
+ return torch.tanh(x)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ albumentations
3
+ numpy
4
+ gradio