Raziel1234 committed on
Commit
77f1528
verified
1 Parent(s): 52a9c4e

Update modeling_pixel.py

Browse files
Files changed (1) hide show
  1. modeling_pixel.py +11 -9
modeling_pixel.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel
4
- from .configuration_pixel import TopAIImageConfig # name fixed here
5
 
6
  class ResidualBlock(nn.Module):
7
  def __init__(self, channels):
@@ -18,19 +18,21 @@ class ResidualBlock(nn.Module):
18
 
19
  class TopAIImageGenerator(PreTrainedModel):
20
  config_class = TopAIImageConfig
 
21
 
22
  def __init__(self, config):
23
  super().__init__(config)
24
- self.text_projection = nn.Linear(config.input_dim, 4 * 4 * 1024)
 
25
 
26
  self.decoder = nn.Sequential(
27
- self._upsample(1024, 512),
28
- ResidualBlock(512),
29
- self._upsample(512, 256),
30
  ResidualBlock(256),
31
- self._upsample(256, 128),
32
- self._upsample(128, 64),
33
- self._upsample(64, 32),
 
 
34
  nn.Conv2d(32, config.image_channels, kernel_size=3, padding=1),
35
  nn.Tanh()
36
  )
@@ -44,5 +46,5 @@ class TopAIImageGenerator(PreTrainedModel):
44
 
45
  def forward(self, text_embeddings):
46
  x = self.text_projection(text_embeddings)
47
- x = x.view(-1, 1024, 4, 4)
48
  return self.decoder(x)
 
1
  import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel
4
+ from .configuration_pixel import TopAIImageConfig
5
 
6
  class ResidualBlock(nn.Module):
7
  def __init__(self, channels):
 
18
 
19
  class TopAIImageGenerator(PreTrainedModel):
20
  config_class = TopAIImageConfig
21
+ all_tied_weights_keys = [] # prevents the AttributeError in Transformers
22
 
23
  def __init__(self, config):
24
  super().__init__(config)
25
+ # dimension fixed to 4*4*512 to match your checkpoint
26
+ self.text_projection = nn.Linear(config.input_dim, 4 * 4 * config.hidden_dim)
27
 
28
  self.decoder = nn.Sequential(
29
+ self._upsample(config.hidden_dim, 256), # 4 -> 8
 
 
30
  ResidualBlock(256),
31
+ self._upsample(256, 128), # 8 -> 16
32
+ ResidualBlock(128),
33
+ self._upsample(128, 64), # 16 -> 32
34
+ self._upsample(64, 32), # 32 -> 64
35
+ self._upsample(32, 32), # 64 -> 128 (fix for a consistent dimension)
36
  nn.Conv2d(32, config.image_channels, kernel_size=3, padding=1),
37
  nn.Tanh()
38
  )
 
46
 
47
  def forward(self, text_embeddings):
48
  x = self.text_projection(text_embeddings)
49
+ x = x.view(-1, self.config.hidden_dim, 4, 4)
50
  return self.decoder(x)