yasserrmd commited on
Commit
bb81f15
·
verified ·
1 Parent(s): d60b639

Update modeling_diffusion.py

Browse files
Files changed (1) hide show
  1. modeling_diffusion.py +10 -1
modeling_diffusion.py CHANGED
@@ -5,6 +5,15 @@ class DiffusionTextModel(nn.Module, PyTorchModelHubMixin):
5
  def __init__(self, vocab_size, max_seq_len, max_time_steps,
6
  embed_dim=128, n_layers=4, n_heads=4):
7
  super().__init__()
 
 
 
 
 
 
 
 
 
8
  self.token_emb = nn.Embedding(vocab_size, embed_dim)
9
  self.pos_emb = nn.Embedding(max_seq_len, embed_dim)
10
  self.time_emb = nn.Embedding(max_time_steps+1, embed_dim)
@@ -23,4 +32,4 @@ class DiffusionTextModel(nn.Module, PyTorchModelHubMixin):
23
  tim = self.time_emb(t).unsqueeze(1).expand(B, L, -1)
24
  h = tok + pos + tim
25
  h = self.transformer(h.transpose(0,1)).transpose(0,1)
26
- return self.out(h)
 
5
  def __init__(self, vocab_size, max_seq_len, max_time_steps,
6
  embed_dim=128, n_layers=4, n_heads=4):
7
  super().__init__()
8
+ self.config = {
9
+ "vocab_size": vocab_size,
10
+ "max_seq_len": max_seq_len,
11
+ "max_time_steps": max_time_steps,
12
+ "embed_dim": embed_dim,
13
+ "n_layers": n_layers,
14
+ "n_heads": n_heads
15
+ }
16
+
17
  self.token_emb = nn.Embedding(vocab_size, embed_dim)
18
  self.pos_emb = nn.Embedding(max_seq_len, embed_dim)
19
  self.time_emb = nn.Embedding(max_time_steps+1, embed_dim)
 
32
  tim = self.time_emb(t).unsqueeze(1).expand(B, L, -1)
33
  h = tok + pos + tim
34
  h = self.transformer(h.transpose(0,1)).transpose(0,1)
35
+ return self.out(h)