Update modeling_diffusion.py
Browse files- modeling_diffusion.py +10 -1
modeling_diffusion.py
CHANGED
|
@@ -5,6 +5,15 @@ class DiffusionTextModel(nn.Module, PyTorchModelHubMixin):
|
|
| 5 |
def __init__(self, vocab_size, max_seq_len, max_time_steps,
|
| 6 |
embed_dim=128, n_layers=4, n_heads=4):
|
| 7 |
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
self.token_emb = nn.Embedding(vocab_size, embed_dim)
|
| 9 |
self.pos_emb = nn.Embedding(max_seq_len, embed_dim)
|
| 10 |
self.time_emb = nn.Embedding(max_time_steps+1, embed_dim)
|
|
@@ -23,4 +32,4 @@ class DiffusionTextModel(nn.Module, PyTorchModelHubMixin):
|
|
| 23 |
tim = self.time_emb(t).unsqueeze(1).expand(B, L, -1)
|
| 24 |
h = tok + pos + tim
|
| 25 |
h = self.transformer(h.transpose(0,1)).transpose(0,1)
|
| 26 |
-
return self.out(h)
|
|
|
|
| 5 |
def __init__(self, vocab_size, max_seq_len, max_time_steps,
|
| 6 |
embed_dim=128, n_layers=4, n_heads=4):
|
| 7 |
super().__init__()
|
| 8 |
+
self.config = {
|
| 9 |
+
"vocab_size": vocab_size,
|
| 10 |
+
"max_seq_len": max_seq_len,
|
| 11 |
+
"max_time_steps": max_time_steps,
|
| 12 |
+
"embed_dim": embed_dim,
|
| 13 |
+
"n_layers": n_layers,
|
| 14 |
+
"n_heads": n_heads
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
self.token_emb = nn.Embedding(vocab_size, embed_dim)
|
| 18 |
self.pos_emb = nn.Embedding(max_seq_len, embed_dim)
|
| 19 |
self.time_emb = nn.Embedding(max_time_steps+1, embed_dim)
|
|
|
|
| 32 |
tim = self.time_emb(t).unsqueeze(1).expand(B, L, -1)
|
| 33 |
h = tok + pos + tim
|
| 34 |
h = self.transformer(h.transpose(0,1)).transpose(0,1)
|
| 35 |
+
return self.out(h)
|