Fgdfgfthgr committed on
Commit
17abc30
·
verified ·
1 Parent(s): 1609d46

Upload 2 files

Browse files
Files changed (2) hide show
  1. Final_8.ckpt +3 -0
  2. minimal_script.py +143 -0
Final_8.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e55bad6f0eff1fe9cf966005ce2c7bae2bacb8dfc80ac524428a179bfd757782
3
+ size 12632827
minimal_script.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import math
4
+ import numpy as np
5
+ import torch.nn as nn
6
+ import lightning.pytorch as pl
7
+ import imageio
8
+ from torchvision.transforms import v2
9
+
10
+
11
class BasicBlock(nn.Module):
    """A plain stack of Conv2d -> InstanceNorm2d -> ReLU stages.

    ``channels`` is a sequence of channel counts; stage ``i`` maps
    ``channels[i]`` feature maps to ``channels[i+1]``. Convolutions use
    'same' reflect padding and no bias (the norm makes bias redundant).
    """

    def __init__(self, channels, kernel_size=(3, 3)):
        super().__init__()
        stages = []
        # Walk consecutive (in, out) channel pairs instead of indexing.
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=kernel_size,
                          padding='same', padding_mode='reflect', bias=False),
                nn.InstanceNorm2d(c_out, affine=False),
                nn.ReLU(),
            ]
        self.operations = nn.Sequential(*stages)

    def forward(self, x):
        """Run every conv/norm/activation stage in order."""
        return self.operations(x)
25
+
26
+
27
class ResBlock(nn.Module):
    """Residual block: out = (skip(x) + conv_stack(x)) / sqrt(2).

    The 1/sqrt(2) factor keeps the variance of the sum roughly constant.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), num_conv=2):
        super().__init__()
        # Skip path: identity when channels match, 1x1 projection otherwise.
        # (Created before the conv stack so parameter/RNG order is stable.)
        self.mapping = (nn.Identity() if in_channels == out_channels
                        else nn.Conv2d(in_channels, out_channels, 1))
        stack = []
        for idx in range(num_conv):
            stack.append(nn.Conv2d(
                in_channels if idx == 0 else out_channels, out_channels,
                kernel_size=kernel_size, padding='same',
                padding_mode='reflect', bias=False))
            stack.append(nn.InstanceNorm2d(out_channels, affine=False))
            stack.append(nn.ReLU())
        self.operations = nn.Sequential(*stack)

    def forward(self, x):
        """Sum the skip and conv paths, rescaled to preserve variance."""
        return (self.mapping(x) + self.operations(x)) / math.sqrt(2)
44
+
45
+
46
class ConvPool(nn.Module):
    """Learned 2x downsampling: stride-2 conv + instance norm + ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 4x4 kernel, stride 2, padding 1 halves each spatial dimension.
        self.operations = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, 2, 1,
                      bias=False, padding_mode='reflect'),
            nn.InstanceNorm2d(out_channels, affine=False),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Downsample ``x`` by a factor of two per spatial dim."""
        return self.operations(x)
57
+
58
+
59
class EmbeddingNetworkSmall(nn.Module):
    """CNN mapping an RGB image (N, 3, H, W) to an 8-dim style embedding.

    NOTE: attribute names are checkpoint state-dict keys — do not rename.
    """

    def __init__(self):
        super().__init__()
        # Convolutional trunk; each ConvPool halves spatial resolution.
        self.conv1 = BasicBlock((3, 8, 16), (3, 3))
        self.pool1 = ConvPool(16, 32)    # total stride 2
        self.conv2 = ResBlock(32, 32, (3, 3), 3)
        self.pool2 = ConvPool(32, 64)    # total stride 4
        self.conv3 = ResBlock(64, 64, (3, 3), 3)
        self.drop1 = nn.Dropout2d(p=0.25)
        self.pool3 = ConvPool(64, 128)   # total stride 8
        self.conv4 = ResBlock(128, 128, (3, 3), 3)
        # Global pooling + normalised 3-layer MLP head.
        self.adpool = nn.AdaptiveAvgPool2d(1)
        self.poolnorm = nn.LayerNorm(128, elementwise_affine=False)
        self.flatten = nn.Flatten()
        self.drop2 = nn.Dropout(p=0.33)
        self.fc1 = nn.Linear(128, 128, bias=False)
        self.fc1norm = nn.LayerNorm(128, elementwise_affine=False)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(128, 128, bias=False)
        self.fc2norm = nn.LayerNorm(128, elementwise_affine=False)
        self.fc3 = nn.Linear(128, 8)

        self.use_checkpoint = False

    def forward(self, x):
        """Return the 8-dim embedding for a batch of images."""
        # Trunk: alternate residual conv stages and strided downsampling,
        # with spatial dropout applied just before the last pool.
        feats = self.pool1(self.conv1(x))
        feats = self.pool2(self.conv2(feats))
        feats = self.pool3(self.drop1(self.conv3(feats)))
        feats = self.conv4(feats)
        # Head: global average pool -> LayerNorm -> dropout/ReLU -> MLP.
        vec = self.flatten(self.adpool(feats))
        vec = self.act(self.drop2(self.poolnorm(vec)))
        vec = self.act(self.fc1norm(self.fc1(vec)))
        vec = self.act(self.fc2norm(self.fc2(vec)))
        return self.fc3(vec)
96
+
97
+
98
+
99
class PLModule(pl.LightningModule):
    """Minimal Lightning wrapper around EmbeddingNetworkSmall.

    Exists so the trained weights can be restored with
    ``PLModule.load_from_checkpoint(...)``; no training logic is defined
    in this script.
    """

    def __init__(self):
        super().__init__()
        # Store constructor arguments (none currently) in the checkpoint,
        # as Lightning's load_from_checkpoint expects.
        self.save_hyperparameters()
        # `network` is a state-dict key — renaming it breaks checkpoints.
        self.network = EmbeddingNetworkSmall()

    def forward(self, x):
        # Delegate straight to the wrapped embedding network.
        return self.network(x)
107
+
108
+
109
def down_to_1k(img, size=1024):
    """Downscale ``img`` (C, H, W) so its area is at most ``size`` ** 2 px.

    The aspect ratio is preserved (both dims scaled by the same factor,
    floored). Images already small enough are returned unchanged — the
    very same tensor object, no copy.

    Parameters:
        img: image tensor laid out channels-first (C, H, W).
        size: square-root of the maximum allowed pixel area.

    Returns:
        The (possibly resized) image tensor.
    """
    h, w = img.shape[1], img.shape[2]
    area = h * w
    if area > size ** 2:
        scale_factor = (size ** 2 / area) ** 0.5
        new_h = math.floor(h * scale_factor)
        new_w = math.floor(w * scale_factor)
        # BUG FIX: torchvision's resize expects (height, width); the
        # original passed (new_w, new_h), transposing the target size
        # for every non-square image.
        img = v2.functional.resize(img, (new_h, new_w))
    return img
118
+
119
+
120
def closest_interval(img, interval=8):
    """Center-crop ``img`` (C, H, W) so H and W are multiples of ``interval``.

    Each dimension becomes the largest multiple of ``interval`` that fits,
    floored at ``interval`` itself; the crop window is centered and clamped
    to the image bounds, so a dimension smaller than ``interval`` is kept
    whole rather than half-discarded.

    Returns a view (slice) of ``img`` — no copy is made.
    """
    _, h, w = img.shape
    # Largest multiple of `interval` not exceeding each dim (the original's
    # `x - x % interval if x % interval != 0 else x` is the same thing),
    # floored at `interval`. The original applied this floor only AFTER
    # computing the offsets, which mis-centered crops of tiny images.
    new_h = max(h - h % interval, interval)
    new_w = max(w - w % interval, interval)
    # Clamp starts at 0 so the window never begins outside the image.
    h_start = max((h - new_h) // 2, 0)
    w_start = max((w - new_w) // 2, 0)
    return img[:, h_start:h_start + new_h, w_start:w_start + new_w]
128
+
129
+
130
if __name__ == '__main__':
    # Prefer the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Restore the trained embedding model and disable dropout/training modes.
    model = PLModule.load_from_checkpoint('Final_8.ckpt')
    model.to(device)
    model.eval()

    # Load the sample image as HWC uint8, reorder to CHW, shrink it under
    # ~1k x 1k pixels, crop to a multiple of the network stride, and
    # rescale pixel values from [0, 255] to [-1, 1].
    img = imageio.v3.imread('images_for_style_embedding/6857740.webp').copy()
    img = torch.from_numpy(img).permute(2, 0, 1)
    img = closest_interval(down_to_1k(img))
    img = 2 * (img / 255) - 1
    img = img.unsqueeze(0).to(device)

    # FIX: run inference without autograd — the original built a gradient
    # graph for no reason, wasting memory on large inputs.
    with torch.no_grad():
        pred = model(img)
    print(pred)