Raziel1234 commited on
Commit
b6b44ab
·
verified ·
1 Parent(s): f27f371

Create modeling_pixel.py

Browse files
Files changed (1) hide show
  1. modeling_pixel.py +48 -0
modeling_pixel.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel
4
+ from .configuration_pixel import PixelConfig
5
+
6
+ class ResidualBlock(nn.Module):
7
+ def __init__(self, channels):
8
+ super().__init__()
9
+ self.block = nn.Sequential(
10
+ nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
11
+ nn.BatchNorm2d(channels),
12
+ nn.ReLU(True),
13
+ nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
14
+ nn.BatchNorm2d(channels)
15
+ )
16
+ def forward(self, x):
17
+ return x + self.block(x)
18
+
19
+ class PixelGenerator(PreTrainedModel):
20
+ config_class = PixelConfig
21
+
22
+ def __init__(self, config):
23
+ super().__init__(config)
24
+ self.text_projection = nn.Linear(config.input_dim, 4 * 4 * 1024)
25
+
26
+ self.decoder = nn.Sequential(
27
+ self._upsample(1024, 512), # 4x4 -> 8x8
28
+ ResidualBlock(512),
29
+ self._upsample(512, 256), # 8x8 -> 16x16
30
+ ResidualBlock(256),
31
+ self._upsample(256, 128), # 16x16 -> 32x32
32
+ self._upsample(128, 64), # 32x32 -> 64x64
33
+ self._upsample(64, 32), # 64x64 -> 128x128
34
+ nn.Conv2d(32, config.image_channels, kernel_size=3, padding=1),
35
+ nn.Tanh()
36
+ )
37
+
38
+ def _upsample(self, i, o):
39
+ return nn.Sequential(
40
+ nn.ConvTranspose2d(i, o, 4, 2, 1, bias=False),
41
+ nn.BatchNorm2d(o),
42
+ nn.ReLU(True)
43
+ )
44
+
45
+ def forward(self, text_embeddings):
46
+ x = self.text_projection(text_embeddings)
47
+ x = x.view(-1, 1024, 4, 4)
48
+ return self.decoder(x)