Isi99999 commited on
Commit
d6be474
·
verified ·
1 Parent(s): 8fd6e83

Adding refine.py to 4.22Lite

Browse files
Files changed (1) hide show
  1. 4.22Lite/train_log/refine.py +90 -0
4.22Lite/train_log/refine.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from torch.optim import AdamW
5
+ import torch.optim as optim
6
+ import itertools
7
+ from model.warplayer import warp
8
+ from torch.nn.parallel import DistributedDataParallel as DDP
9
+ import torch.nn.functional as F
10
+
11
+ device = torch.device("cuda")
12
+
13
+ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
14
+ return nn.Sequential(
15
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
16
+ padding=padding, dilation=dilation, bias=True),
17
+ nn.LeakyReLU(0.2, True)
18
+ )
19
+
20
+ def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
21
+ return nn.Sequential(
22
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
23
+ padding=padding, dilation=dilation, bias=True),
24
+ )
25
+
26
+ def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
27
+ return nn.Sequential(
28
+ torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True),
29
+ nn.LeakyReLU(0.2, True)
30
+ )
31
+
32
+ class Conv2(nn.Module):
33
+ def __init__(self, in_planes, out_planes, stride=2):
34
+ super(Conv2, self).__init__()
35
+ self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
36
+ self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
37
+
38
+ def forward(self, x):
39
+ x = self.conv1(x)
40
+ x = self.conv2(x)
41
+ return x
42
+
43
+ c = 16
44
+ class Contextnet(nn.Module):
45
+ def __init__(self):
46
+ super(Contextnet, self).__init__()
47
+ self.conv1 = Conv2(3, c)
48
+ self.conv2 = Conv2(c, 2*c)
49
+ self.conv3 = Conv2(2*c, 4*c)
50
+ self.conv4 = Conv2(4*c, 8*c)
51
+
52
+ def forward(self, x, flow):
53
+ x = self.conv1(x)
54
+ flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
55
+ f1 = warp(x, flow)
56
+ x = self.conv2(x)
57
+ flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
58
+ f2 = warp(x, flow)
59
+ x = self.conv3(x)
60
+ flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
61
+ f3 = warp(x, flow)
62
+ x = self.conv4(x)
63
+ flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
64
+ f4 = warp(x, flow)
65
+ return [f1, f2, f3, f4]
66
+
67
+ class Unet(nn.Module):
68
+ def __init__(self):
69
+ super(Unet, self).__init__()
70
+ self.down0 = Conv2(17, 2*c)
71
+ self.down1 = Conv2(4*c, 4*c)
72
+ self.down2 = Conv2(8*c, 8*c)
73
+ self.down3 = Conv2(16*c, 16*c)
74
+ self.up0 = deconv(32*c, 8*c)
75
+ self.up1 = deconv(16*c, 4*c)
76
+ self.up2 = deconv(8*c, 2*c)
77
+ self.up3 = deconv(4*c, c)
78
+ self.conv = nn.Conv2d(c, 3, 3, 1, 1)
79
+
80
+ def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
81
+ s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1))
82
+ s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
83
+ s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
84
+ s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
85
+ x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
86
+ x = self.up1(torch.cat((x, s2), 1))
87
+ x = self.up2(torch.cat((x, s1), 1))
88
+ x = self.up3(torch.cat((x, s0), 1))
89
+ x = self.conv(x)
90
+ return torch.sigmoid(x)