Commit: Add print statements
Browse files — modeling_cogvlm.py (+6, −2)
modeling_cogvlm.py
CHANGED
|
@@ -290,12 +290,14 @@ class CogVLMDecoderLayer(nn.Module):
|
|
| 290 |
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 291 |
output_attentions: Optional[bool] = False,
|
| 292 |
use_cache: Optional[bool] = False,
|
|
|
|
| 293 |
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 294 |
residual = hidden_states
|
| 295 |
|
| 296 |
hidden_states = self.input_layernorm(hidden_states)
|
| 297 |
|
| 298 |
-
|
|
|
|
| 299 |
|
| 300 |
# Self Attention
|
| 301 |
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
|
@@ -308,7 +310,8 @@ class CogVLMDecoderLayer(nn.Module):
|
|
| 308 |
use_cache=use_cache,
|
| 309 |
)
|
| 310 |
|
| 311 |
-
|
|
|
|
| 312 |
|
| 313 |
hidden_states = residual + hidden_states
|
| 314 |
|
|
@@ -539,6 +542,7 @@ class CogVLMModel(CogVLMPreTrainedModel):
|
|
| 539 |
past_key_value=past_key_value,
|
| 540 |
output_attentions=output_attentions,
|
| 541 |
use_cache=use_cache,
|
|
|
|
| 542 |
)
|
| 543 |
hidden_states = layer_outputs[0]
|
| 544 |
|
|
|
|
| 290 |
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 291 |
output_attentions: Optional[bool] = False,
|
| 292 |
use_cache: Optional[bool] = False,
|
| 293 |
+
print_values = False,
|
| 294 |
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 295 |
residual = hidden_states
|
| 296 |
|
| 297 |
hidden_states = self.input_layernorm(hidden_states)
|
| 298 |
|
| 299 |
+
if print_values:
|
| 300 |
+
print("Hidden states before self attention:", hidden_states[0,:3,:3])
|
| 301 |
|
| 302 |
# Self Attention
|
| 303 |
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
|
|
|
| 310 |
use_cache=use_cache,
|
| 311 |
)
|
| 312 |
|
| 313 |
+
if print_values:
|
| 314 |
+
print("Hidden states after self attention:", hidden_states[0,:3,:3])
|
| 315 |
|
| 316 |
hidden_states = residual + hidden_states
|
| 317 |
|
|
|
|
| 542 |
past_key_value=past_key_value,
|
| 543 |
output_attentions=output_attentions,
|
| 544 |
use_cache=use_cache,
|
| 545 |
+
print_values=idx in [0, 1, 2],
|
| 546 |
)
|
| 547 |
hidden_states = layer_outputs[0]
|
| 548 |
|