twarner committed on
Commit
df5c606
·
1 Parent(s): 4fe9c3a

Fix causal mask dtype for float16

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -75,6 +75,7 @@ class GcodeDecoder(nn.Module):
75
  def forward(self, latent: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
76
  batch_size, seq_len = input_ids.shape
77
  device = input_ids.device
 
78
 
79
  latent_flat = latent.view(batch_size, -1)
80
  memory = self.latent_proj(latent_flat)
@@ -83,7 +84,8 @@ class GcodeDecoder(nn.Module):
83
  positions = torch.arange(seq_len, device=device)
84
  x = self.token_embed(input_ids) + self.pos_embed(positions)
85
 
86
- causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device)
 
87
 
88
  for layer in self.layers:
89
  x = layer(x, memory, tgt_mask=causal_mask)
 
75
  def forward(self, latent: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
76
  batch_size, seq_len = input_ids.shape
77
  device = input_ids.device
78
+ dtype = latent.dtype
79
 
80
  latent_flat = latent.view(batch_size, -1)
81
  memory = self.latent_proj(latent_flat)
 
84
  positions = torch.arange(seq_len, device=device)
85
  x = self.token_embed(input_ids) + self.pos_embed(positions)
86
 
87
+ # Causal mask must match dtype for attention
88
+ causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device, dtype=dtype)
89
 
90
  for layer in self.layers:
91
  x = layer(x, memory, tgt_mask=causal_mask)