Harsimran19 commited on
Commit
0539ad6
·
1 Parent(s): 639186b

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +101 -12
model.py CHANGED
@@ -1,17 +1,106 @@
1
  import torch
2
- # from torchvision import transforms
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  def gen_model():
6
- # transform = transforms.Compose([
7
- # transforms.Resize((256, 256)),
8
- # transforms.RandomCrop((224, 224)),
9
- # transforms.RandomHorizontalFlip(),
10
- # transforms.ToTensor(),
11
- # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
12
  device="cpu"
13
- with torch.no_grad():
14
- model = torch.load('gen.pth.tar', map_location='cpu')
15
- model = model['state_dict']
16
- model=model['state_dict']
17
- return model,transform
 
 
1
  import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms
4
 
5
class Block(nn.Module):
    """One pix2pix U-Net stage: strided conv (down) or transposed conv (up),
    followed by batch-norm, an activation, and optional dropout.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the conv.
        down: True -> Conv2d that halves H/W; False -> ConvTranspose2d that doubles H/W.
        act: "relu" for ReLU, anything else for LeakyReLU(0.2).
        use_dropout: apply Dropout(0.5) after the conv stack in forward.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super(Block, self).__init__()
        # kernel=4, stride=2, padding=1 halves (down) or doubles (up) spatial size.
        if down:
            conv = nn.Conv2d(
                in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect"
            )
        else:
            conv = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
        nonlinearity = nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2)
        self.conv = nn.Sequential(conv, nn.BatchNorm2d(out_channels), nonlinearity)

        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
        self.down = down  # stored for introspection; not read in forward

    def forward(self, x):
        out = self.conv(x)
        if self.use_dropout:
            out = self.dropout(out)
        return out
23
+
24
+
25
class Generator(nn.Module):
    """Pix2pix U-Net generator: 7 encoder stages + bottleneck + 7 decoder stages
    with skip connections, ending in a Tanh output in [-1, 1].

    Attribute names (down1..down6, up1..up7, initial_down, bottleneck, final_up)
    are part of the checkpoint state-dict layout and must not be renamed.

    Args:
        in_channels: channels of the input/output image (default 3).
        features: base channel width; deeper stages use multiples of it.
    """

    def __init__(self, in_channels=3, features=64):
        super().__init__()
        f = features
        # Encoder: every stage halves the spatial resolution.
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, f, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        self.down1 = Block(f, f * 2, down=True, act="leaky", use_dropout=False)
        self.down2 = Block(f * 2, f * 4, down=True, act="leaky", use_dropout=False)
        self.down3 = Block(f * 4, f * 8, down=True, act="leaky", use_dropout=False)
        self.down4 = Block(f * 8, f * 8, down=True, act="leaky", use_dropout=False)
        self.down5 = Block(f * 8, f * 8, down=True, act="leaky", use_dropout=False)
        self.down6 = Block(f * 8, f * 8, down=True, act="leaky", use_dropout=False)
        self.bottleneck = nn.Sequential(nn.Conv2d(f * 8, f * 8, 4, 2, 1), nn.ReLU())

        # Decoder: inputs from up2 onward are concatenated with the mirror-image
        # encoder output, hence the doubled in-channel counts.
        self.up1 = Block(f * 8, f * 8, down=False, act="relu", use_dropout=True)
        self.up2 = Block(f * 16, f * 8, down=False, act="relu", use_dropout=True)
        self.up3 = Block(f * 16, f * 8, down=False, act="relu", use_dropout=True)
        self.up4 = Block(f * 16, f * 8, down=False, act="relu", use_dropout=False)
        self.up5 = Block(f * 16, f * 4, down=False, act="relu", use_dropout=False)
        self.up6 = Block(f * 8, f * 2, down=False, act="relu", use_dropout=False)
        self.up7 = Block(f * 4, f, down=False, act="relu", use_dropout=False)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(f * 2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the U-Net; expects spatial size divisible by 2**8 (e.g. 256x256)."""
        # Collect encoder outputs d1..d7 as skip connections.
        skips = [self.initial_down(x)]
        for down in (self.down1, self.down2, self.down3,
                     self.down4, self.down5, self.down6):
            skips.append(down(skips[-1]))

        out = self.up1(self.bottleneck(skips[-1]))
        # up2 pairs with d7, up3 with d6, ... up7 with d2.
        ups = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)
        for up, skip in zip(ups, reversed(skips[1:])):
            out = up(torch.cat([out, skip], 1))
        # Final stage pairs with d1 and maps back to image channels.
        return self.final_up(torch.cat([out, skips[0]], 1))
91
def load_model(name, device="cpu", features=64):
    """Build a Generator and load trained weights from the checkpoint at *name*.

    Fixes the original body, which returned an undefined global ``G`` on an
    undefined ``device`` and ignored ``name`` entirely.

    Args:
        name: path to a checkpoint saved with ``torch.save`` — either a raw
            state dict or a dict wrapping it under the ``"state_dict"`` key
            (the latter matches the ``gen.pth.tar`` layout used in gen_model).
        device: map location and target device (default "cpu", matching the
            rest of this module).
        features: base channel width the checkpoint was trained with.

    Returns:
        The Generator in eval mode, moved to *device*.
    """
    gen = Generator(features=features)
    ckpt = torch.load(name, map_location=device)
    # Unwrap {"state_dict": ...} checkpoints; pass raw state dicts through.
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]
    gen.load_state_dict(ckpt)
    gen.eval()  # inference mode: disables dropout, uses BN running stats
    return gen.to(device)
94
 
95
def gen_model():
    """Return the trained pix2pix generator and its input preprocessing transform.

    Fixes the original body, which (a) raised ``NameError`` by calling
    ``G.load_state_dict`` / ``G.eval`` on an undefined ``G``, and (b) loaded
    the ``gen.pth.tar`` state dict into a local that was then discarded, so
    the returned ``gen`` kept random weights.

    Returns:
        (gen, transform): the Generator in eval mode on CPU, and a
        torchvision transform producing 256x256 tensors normalized to [-1, 1].
    """
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    device = "cpu"
    gen = Generator().to(device)
    # NOTE(review): the original also referenced a second file, "G101.pth",
    # via the broken ``G`` lines; gen.pth.tar is the checkpoint both the old
    # and new versions load, so it is treated as the authoritative weights —
    # confirm against how the app ships its checkpoints.
    state = torch.load('gen.pth.tar', map_location=device)
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    gen.load_state_dict(state)
    gen.eval()  # inference mode: disables dropout, uses BN running stats
    return gen, transform