Add print statements
modeling_cogvlm.py  CHANGED  (+7 -7)
@@ -241,7 +241,7 @@ class VisionExpertAttention(nn.Module):
         key_states = self._transpose_for_scores(key_states)  # B, H, L, HD
         value_states = self._transpose_for_scores(value_states)  # B, H, L, HD

-        if print_values:
+        # if print_values:

         # torch.save(query_states, "query_states.pt")
         # torch.save(key_states, "key_states.pt")
@@ -325,13 +325,13 @@ class CogVLMDecoderLayer(nn.Module):
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states

-        if print_values:
-            print("Hidden states before RMS norm:", hidden_states[0, :3, :3])
+        # if print_values:
+        # print("Hidden states before RMS norm:", hidden_states[0, :3, :3])

         hidden_states = self.input_layernorm(hidden_states)

-        if print_values:
-            print("Hidden states after RMS norm, before self attention:", hidden_states[0,:3,:3])
+        # if print_values:
+        # print("Hidden states after RMS norm, before self attention:", hidden_states[0,:3,:3])

         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
@@ -345,8 +345,8 @@ class CogVLMDecoderLayer(nn.Module):
             print_values=print_values,
         )

-        if print_values:
-            print("Hidden states after self attention:", hidden_states[0,:3,:3])
+        # if print_values:
+        # print("Hidden states after self attention:", hidden_states[0,:3,:3])

         hidden_states = residual + hidden_states
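The toggled lines follow a simple debugging pattern: a print_values flag is threaded through the forward pass so intermediate activations can be printed (or dumped with torch.save) and compared against a reference run. Below is a minimal, self-contained sketch of that pattern using only standard PyTorch; TinyDecoderLayer and its submodules are illustrative stand-ins, not the actual CogVLM classes.

import torch
from torch import nn

class TinyDecoderLayer(nn.Module):
    """Illustrative stand-in for a decoder layer with a print_values debug flag."""

    def __init__(self, hidden_size: int = 8, num_heads: int = 2):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, hidden_states: torch.Tensor, print_values: bool = False) -> torch.Tensor:
        residual = hidden_states

        if print_values:
            print("Hidden states before norm:", hidden_states[0, :3, :3])

        hidden_states = self.input_layernorm(hidden_states)

        if print_values:
            print("Hidden states after norm, before self attention:", hidden_states[0, :3, :3])

        # Self attention (stand-in for the vision-expert attention used in CogVLM)
        hidden_states, _ = self.self_attn(hidden_states, hidden_states, hidden_states)

        if print_values:
            print("Hidden states after self attention:", hidden_states[0, :3, :3])
            # torch.save(hidden_states, "hidden_states.pt")  # optional dump for offline comparison

        return residual + hidden_states


# Usage: enable the flag for a single forward pass to inspect intermediate values.
layer = TinyDecoderLayer()
x = torch.randn(1, 4, 8)
out = layer(x, print_values=True)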