Fix causal mask dtype for float16
app.py CHANGED
@@ -75,6 +75,7 @@ class GcodeDecoder(nn.Module):
     def forward(self, latent: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
         batch_size, seq_len = input_ids.shape
         device = input_ids.device
+        dtype = latent.dtype

         latent_flat = latent.view(batch_size, -1)
         memory = self.latent_proj(latent_flat)
@@ -83,7 +84,8 @@ class GcodeDecoder(nn.Module):
         positions = torch.arange(seq_len, device=device)
         x = self.token_embed(input_ids) + self.pos_embed(positions)

-        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device)
+        # Causal mask must match dtype for attention
+        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device, dtype=dtype)

         for layer in self.layers:
             x = layer(x, memory, tgt_mask=causal_mask)
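For context, `nn.Transformer.generate_square_subsequent_mask` returns a float32 additive mask by default (`0` on and below the diagonal, `-inf` above), so under float16 inference the mask no longer matches the activation dtype that attention expects. Below is a minimal standalone sketch of the failure mode and the fix; the dimensions and the stock decoder layer are illustrative assumptions standing in for the actual model in app.py, not taken from the commit.

import torch
import torch.nn as nn

# Hypothetical setup: dims and layer are illustrative, not from app.py.
# float16 is normally exercised on GPU; fall back to CPU if unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"
layer = nn.TransformerDecoderLayer(d_model=64, nhead=4, batch_first=True).to(device, torch.float16)
x = torch.randn(2, 10, 64, device=device, dtype=torch.float16)       # decoder input
memory = torch.randn(2, 10, 64, device=device, dtype=torch.float16)  # encoder memory

# Without dtype=, the mask comes back float32, mismatched with the
# float16 activations inside attention.
mask_default = nn.Transformer.generate_square_subsequent_mask(10, device=device)
print(mask_default.dtype)  # torch.float32

# With dtype=, the -inf/0 mask matches the activations, as in this commit.
mask_half = nn.Transformer.generate_square_subsequent_mask(10, device=device, dtype=torch.float16)
out = layer(x, memory, tgt_mask=mask_half)
print(out.dtype)  # torch.float16

Deriving the dtype from `latent.dtype` rather than hard-coding `torch.float16` keeps the same code path correct when the model runs in float32.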