marijanic commited on
Commit
8432aac
·
verified ·
1 Parent(s): 43da96c

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. Diffusion-cuda +3 -0
  3. requirements.txt +5 -0
  4. source.py +109 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Diffusion-cuda filter=lfs diff=lfs merge=lfs -text
Diffusion-cuda ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2cc0ad238434f2130d1820d0196f321429392054b9596d9b44b8676782d57dd
3
+ size 51480614
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch==2.2.1
2
+ torchvision==0.17.1
3
+ streamlit==1.33.0
4
+ numpy
5
+ tqdm
source.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+
6
class Block(nn.Module):
    """One U-Net stage: a time-conditioned pair of 3x3 convolutions.

    The image is passed through two parallel convs. One branch is gated
    multiplicatively by a learned projection of the timestep embedding;
    the other acts as the base path. Their sum is layer-normalized over
    the full [C, H, W] block and passed through ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        size: spatial height/width of the feature map. Required up front
            because ``nn.LayerNorm`` here normalizes over explicit
            ``[C, size, size]`` dimensions.
        out_channels: channels produced by both conv branches. Defaults to
            128, the value previously hard-coded.
        t_dim: dimensionality of the timestep embedding passed to
            ``forward``. Defaults to 192, matching ``Model.l_ts``.
    """

    def __init__(self, in_channels=128, size=32, out_channels=128, t_dim=192):
        super(Block, self).__init__()

        # Two parallel 3x3 convs: one time-gated, one as the base path.
        self.conv_param = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
        self.conv_out = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)

        # Projects the timestep embedding to one scale per output channel.
        self.dense_ts = nn.Linear(t_dim, out_channels)

        # Normalizes over [C, H, W]; `size` must match the input's H and W.
        self.layer_norm = nn.LayerNorm([out_channels, size, size])

    def forward(self, x_img, x_ts):
        """Return a (N, out_channels, size, size) map conditioned on x_ts.

        Args:
            x_img: feature map of shape (N, in_channels, size, size).
            x_ts: timestep embedding of shape (N, t_dim).
        """
        x_parameter = F.relu(self.conv_param(x_img))

        # Per-channel multiplicative time gate, broadcast over H and W.
        time_parameter = F.relu(self.dense_ts(x_ts))
        time_parameter = time_parameter.view(-1, self.dense_ts.out_features, 1, 1)
        x_parameter = x_parameter * time_parameter

        x_out = self.conv_out(x_img)
        x_out = x_out + x_parameter
        x_out = F.relu(self.layer_norm(x_out))

        return x_out
30
+
31
class Model(nn.Module):
    """Small U-Net denoiser for 32x32 RGB images, conditioned on a timestep.

    Pipeline: four encoder ``Block``s (32 -> 16 -> 8 -> 4 via max-pooling),
    an MLP bottleneck that mixes the flattened spatial features with the
    timestep embedding, four decoder ``Block``s (4 -> 8 -> 16 -> 32 via
    bilinear upsampling) with skip connections from the encoder, and a
    1x1 conv back to 3 channels.

    Args:
        lr: learning rate for the Adam optimizer constructed in
            ``__init__``. Defaults to 0.0008, the value previously
            hard-coded.
    """

    def __init__(self, lr=0.0008):
        super(Model, self).__init__()

        # Embed the scalar timestep into a 192-d vector.
        self.l_ts = nn.Sequential(
            nn.Linear(1, 192),
            nn.LayerNorm([192]),
            nn.ReLU(),
        )

        # Encoder: each Block outputs 128 channels; pooling halves resolution.
        self.down_x32 = Block(in_channels=3, size=32)
        self.down_x16 = Block(size=16)
        self.down_x8 = Block(size=8)
        self.down_x4 = Block(size=4)

        # Bottleneck MLP: 128*4*4 spatial features + 192 time features = 2240.
        self.mlp = nn.Sequential(
            nn.Linear(2240, 128),
            nn.LayerNorm([128]),
            nn.ReLU(),

            nn.Linear(128, 32 * 4 * 4),  # make [-1, 32, 4, 4]
            nn.LayerNorm([32 * 4 * 4]),
            nn.ReLU(),
        )

        # Decoder: in_channels = concat of upsampled path + encoder skip.
        self.up_x4 = Block(in_channels=32 + 128, size=4)
        self.up_x8 = Block(in_channels=256, size=8)
        self.up_x16 = Block(in_channels=256, size=16)
        self.up_x32 = Block(in_channels=256, size=32)

        # Project back to RGB; no final activation.
        self.cnn_output = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=1, padding=0)

        # make optimizer (lr was previously hard-coded to 0.0008)
        self.opt = torch.optim.Adam(self.parameters(), lr=lr)

    def forward(self, x, x_ts):
        """Denoise ``x`` at timestep ``x_ts``.

        Args:
            x: images of shape (N, 3, 32, 32).
            x_ts: raw timestep of shape (N, 1), embedded via ``l_ts``.

        Returns:
            Tensor of shape (N, 3, 32, 32).
        """
        x_ts = self.l_ts(x_ts)

        # ----- left ( down ) -----
        blocks = [
            self.down_x32,
            self.down_x16,
            self.down_x8,
            self.down_x4,
        ]
        x_left_layers = []
        for i, block in enumerate(blocks):
            x = block(x, x_ts)
            x_left_layers.append(x)  # saved for the skip connections
            if i < len(blocks) - 1:
                x = F.max_pool2d(x, 2)

        # ----- MLP -----
        x = x.view(-1, 128 * 4 * 4)
        x = torch.cat([x, x_ts], dim=1)  # inject time into the bottleneck
        x = self.mlp(x)
        x = x.view(-1, 32, 4, 4)

        # ----- right ( up ) -----
        blocks = [
            self.up_x4,
            self.up_x8,
            self.up_x16,
            self.up_x32,
        ]

        for i, block in enumerate(blocks):
            # cat left: skip from the encoder stage at matching resolution
            x_left = x_left_layers[len(blocks) - i - 1]
            x = torch.cat([x, x_left], dim=1)

            x = block(x, x_ts)
            if i < len(blocks) - 1:
                # align_corners=False is the effective default; stated
                # explicitly to avoid the PyTorch UserWarning, numerics
                # are unchanged.
                x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

        # ----- output -----
        x = self.cnn_output(x)

        return x