Update modeling_pixel.py
Browse files- modeling_pixel.py +7 -7
modeling_pixel.py
CHANGED
|
@@ -18,21 +18,21 @@ class ResidualBlock(nn.Module):
|
|
| 18 |
|
| 19 |
class TopAIImageGenerator(PreTrainedModel):
|
| 20 |
config_class = TopAIImageConfig
|
| 21 |
-
|
|
|
|
| 22 |
|
| 23 |
def __init__(self, config):
|
| 24 |
super().__init__(config)
|
| 25 |
-
#
|
| 26 |
self.text_projection = nn.Linear(config.input_dim, 4 * 4 * config.hidden_dim)
|
| 27 |
|
| 28 |
self.decoder = nn.Sequential(
|
| 29 |
-
self._upsample(config.hidden_dim, 256), #
|
| 30 |
ResidualBlock(256),
|
| 31 |
-
self._upsample(256, 128), #
|
| 32 |
ResidualBlock(128),
|
| 33 |
-
self._upsample(128, 64), #
|
| 34 |
-
self._upsample(64, 32), #
|
| 35 |
-
self._upsample(32, 32), # 64 -> 128 (转讬拽讜谉 诇诪讬诪讚 注拽讘讬)
|
| 36 |
nn.Conv2d(32, config.image_channels, kernel_size=3, padding=1),
|
| 37 |
nn.Tanh()
|
| 38 |
)
|
|
|
|
| 18 |
|
| 19 |
class TopAIImageGenerator(PreTrainedModel):
|
| 20 |
config_class = TopAIImageConfig
|
| 21 |
+
# 转讬拽讜谉 讛砖讙讬讗讛: 诪讙讚讬专讬诐 讻专砖讬诪讛 专讬拽讛, Transformers 讻讘专 讬住转讚专 注诐 讝讛
|
| 22 |
+
all_tied_weights_keys = []
|
| 23 |
|
| 24 |
def __init__(self, config):
|
| 25 |
super().__init__(config)
|
| 26 |
+
# 住讬谞讻专讜谉 讛诪讬诪讚讬诐 诇驻讬 讛-Checkpoint 砖诇讻诐
|
| 27 |
self.text_projection = nn.Linear(config.input_dim, 4 * 4 * config.hidden_dim)
|
| 28 |
|
| 29 |
self.decoder = nn.Sequential(
|
| 30 |
+
self._upsample(config.hidden_dim, 256), # 4x4 -> 8x8
|
| 31 |
ResidualBlock(256),
|
| 32 |
+
self._upsample(256, 128), # 8x8 -> 16x16
|
| 33 |
ResidualBlock(128),
|
| 34 |
+
self._upsample(128, 64), # 16x16 -> 32x32
|
| 35 |
+
self._upsample(64, 32), # 32x32 -> 64x64
|
|
|
|
| 36 |
nn.Conv2d(32, config.image_channels, kernel_size=3, padding=1),
|
| 37 |
nn.Tanh()
|
| 38 |
)
|