Update modeling_pixel.py
Browse files- modeling_pixel.py +14 -10
modeling_pixel.py
CHANGED
|
@@ -18,22 +18,26 @@ class ResidualBlock(nn.Module):
|
|
| 18 |
|
| 19 |
class TopAIImageGenerator(PreTrainedModel):
|
| 20 |
config_class = TopAIImageConfig
|
| 21 |
-
|
| 22 |
-
|
|
|
|
| 23 |
|
| 24 |
def __init__(self, config):
|
| 25 |
super().__init__(config)
|
| 26 |
-
# 住讬谞讻专讜谉 讛诪讬诪讚讬诐 诇驻讬 讛-Checkpoint 砖诇讻诐
|
| 27 |
self.text_projection = nn.Linear(config.input_dim, 4 * 4 * config.hidden_dim)
|
| 28 |
|
|
|
|
| 29 |
self.decoder = nn.Sequential(
|
| 30 |
-
self._upsample(config.hidden_dim, 256), #
|
| 31 |
-
ResidualBlock(256),
|
| 32 |
-
self._upsample(256, 128), #
|
| 33 |
-
ResidualBlock(128),
|
| 34 |
-
self._upsample(128, 64), #
|
| 35 |
-
self._upsample(64, 32), #
|
| 36 |
-
nn.Conv2d(32,
|
|
|
|
|
|
|
|
|
|
| 37 |
nn.Tanh()
|
| 38 |
)
|
| 39 |
|
|
|
|
| 18 |
|
| 19 |
class TopAIImageGenerator(PreTrainedModel):
|
| 20 |
config_class = TopAIImageConfig
|
| 21 |
+
|
| 22 |
+
# 转讬拽讜谉 拽专讬讟讬: 讛讙讚专讛 讻诪讬诇讜谉 专讬拽 讻讚讬 诇诪谞讜注 讗转 讛-AttributeError
|
| 23 |
+
all_tied_weights_keys = {}
|
| 24 |
|
| 25 |
def __init__(self, config):
|
| 26 |
super().__init__(config)
|
|
|
|
| 27 |
self.text_projection = nn.Linear(config.input_dim, 4 * 4 * config.hidden_dim)
|
| 28 |
|
| 29 |
+
# 讘谞讬讬讛 诪讞讚砖 砖诇 讛-Decoder 讻讚讬 诇讛转讗讬诐 诇诪讘谞讛 讛-Checkpoint 砖诇讻诐
|
| 30 |
self.decoder = nn.Sequential(
|
| 31 |
+
self._upsample(config.hidden_dim, 256), # 0
|
| 32 |
+
ResidualBlock(256), # 1
|
| 33 |
+
self._upsample(256, 128), # 2
|
| 34 |
+
ResidualBlock(128), # 3
|
| 35 |
+
self._upsample(128, 64), # 4
|
| 36 |
+
self._upsample(64, 32), # 5
|
| 37 |
+
nn.Conv2d(32, 32, kernel_size=3, padding=1), # 6
|
| 38 |
+
nn.BatchNorm2d(32), # 7
|
| 39 |
+
nn.ReLU(True), # 8
|
| 40 |
+
nn.Conv2d(32, config.image_channels, kernel_size=3, padding=1), # 9
|
| 41 |
nn.Tanh()
|
| 42 |
)
|
| 43 |
|