Update chatNT.py
Browse files
chatNT.py
CHANGED
|
@@ -426,6 +426,7 @@ class TorchBioBrainDecoder(nn.Module):
|
|
| 426 |
)
|
| 427 |
|
| 428 |
# Regular GPT pass through
|
|
|
|
| 429 |
embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
|
| 430 |
embeddings = self.gpt_model.final_norm(embeddings)
|
| 431 |
|
|
@@ -885,6 +886,7 @@ class TorchGptGroupedQueryAttention(nn.Module):
|
|
| 885 |
value_inputs: torch.Tensor,
|
| 886 |
attention_mask: torch.Tensor = None,
|
| 887 |
) -> torch.Tensor:
|
|
|
|
| 888 |
batch_size, seq_len, _ = query_inputs.shape
|
| 889 |
|
| 890 |
queries = self.query_linear(query_inputs).view( # noqa
|
|
@@ -966,6 +968,7 @@ class TorchGptDecoder(nn.Module):
|
|
| 966 |
if attention_mask is None:
|
| 967 |
attention_mask = build_causal_attention_mask(1, embeddings.shape[1])
|
| 968 |
for layer in self.layers:
|
|
|
|
| 969 |
embeddings = layer(embeddings, attention_mask)
|
| 970 |
|
| 971 |
return embeddings
|
|
|
|
| 426 |
)
|
| 427 |
|
| 428 |
# Regular GPT pass through
|
| 429 |
+
print("(debug) tokens embeddings shape : ", tokens_embeddings.shape)
|
| 430 |
embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
|
| 431 |
embeddings = self.gpt_model.final_norm(embeddings)
|
| 432 |
|
|
|
|
| 886 |
value_inputs: torch.Tensor,
|
| 887 |
attention_mask: torch.Tensor = None,
|
| 888 |
) -> torch.Tensor:
|
| 889 |
+
print("(debug) Query input shape : ", query_inputs.shape)
|
| 890 |
batch_size, seq_len, _ = query_inputs.shape
|
| 891 |
|
| 892 |
queries = self.query_linear(query_inputs).view( # noqa
|
|
|
|
| 968 |
if attention_mask is None:
|
| 969 |
attention_mask = build_causal_attention_mask(1, embeddings.shape[1])
|
| 970 |
for layer in self.layers:
|
| 971 |
+
print("Embedding shape in apply_transformer_layers : ", embeddings.shape)
|
| 972 |
embeddings = layer(embeddings, attention_mask)
|
| 973 |
|
| 974 |
return embeddings
|