twarner commited on
Commit
4fe9c3a
·
1 Parent(s): 916a1f7

Fix decoder architecture to match v2 training

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -54,16 +54,19 @@ class GcodeDecoder(nn.Module):
54
  self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
55
  self.pos_embed = nn.Embedding(config.max_seq_len, config.hidden_size)
56
 
57
- decoder_layer = nn.TransformerDecoderLayer(
58
- d_model=config.hidden_size,
59
- nhead=config.num_heads,
60
- dim_feedforward=config.hidden_size * 4,
61
- dropout=config.dropout,
62
- activation='gelu',
63
- batch_first=True,
64
- norm_first=True,
65
- )
66
- self.decoder = nn.TransformerDecoder(decoder_layer, config.num_layers)
 
 
 
67
 
68
  self.ln_f = nn.LayerNorm(config.hidden_size)
69
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -82,7 +85,9 @@ class GcodeDecoder(nn.Module):
82
 
83
  causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device)
84
 
85
- x = self.decoder(x, memory, tgt_mask=causal_mask)
 
 
86
  x = self.ln_f(x)
87
  return self.lm_head(x)
88
 
 
54
  self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
55
  self.pos_embed = nn.Embedding(config.max_seq_len, config.hidden_size)
56
 
57
+ # Individual layers (matches v2 training architecture)
58
+ self.layers = nn.ModuleList([
59
+ nn.TransformerDecoderLayer(
60
+ d_model=config.hidden_size,
61
+ nhead=config.num_heads,
62
+ dim_feedforward=config.hidden_size * 4,
63
+ dropout=config.dropout,
64
+ activation='gelu',
65
+ batch_first=True,
66
+ norm_first=True,
67
+ )
68
+ for _ in range(config.num_layers)
69
+ ])
70
 
71
  self.ln_f = nn.LayerNorm(config.hidden_size)
72
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
85
 
86
  causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device)
87
 
88
+ for layer in self.layers:
89
+ x = layer(x, memory, tgt_mask=causal_mask)
90
+
91
  x = self.ln_f(x)
92
  return self.lm_head(x)
93