Add print statements
Browse files- modeling_cogvlm.py +30 -0
modeling_cogvlm.py
CHANGED
|
@@ -225,6 +225,7 @@ class VisionExpertAttention(nn.Module):
|
|
| 225 |
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 226 |
output_attentions: bool = False,
|
| 227 |
use_cache: bool = False,
|
|
|
|
| 228 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 229 |
bsz, q_len, _ = hidden_states.size()
|
| 230 |
vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
|
|
@@ -240,6 +241,34 @@ class VisionExpertAttention(nn.Module):
|
|
| 240 |
key_states = self._transpose_for_scores(key_states) # B, H, L, HD
|
| 241 |
value_states = self._transpose_for_scores(value_states) # B, H, L, HD
|
| 242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
kv_seq_len = key_states.shape[-2]
|
| 244 |
if past_key_value is not None:
|
| 245 |
kv_seq_len += past_key_value[0].shape[-2]
|
|
@@ -308,6 +337,7 @@ class CogVLMDecoderLayer(nn.Module):
|
|
| 308 |
past_key_value=past_key_value,
|
| 309 |
output_attentions=output_attentions,
|
| 310 |
use_cache=use_cache,
|
|
|
|
| 311 |
)
|
| 312 |
|
| 313 |
if print_values:
|
|
|
|
| 225 |
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 226 |
output_attentions: bool = False,
|
| 227 |
use_cache: bool = False,
|
| 228 |
+
print_values: bool = False,
|
| 229 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 230 |
bsz, q_len, _ = hidden_states.size()
|
| 231 |
vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
|
|
|
|
| 241 |
key_states = self._transpose_for_scores(key_states) # B, H, L, HD
|
| 242 |
value_states = self._transpose_for_scores(value_states) # B, H, L, HD
|
| 243 |
|
| 244 |
+
torch.save(query_states, "query_states.pt")
|
| 245 |
+
torch.save(key_states, "key_states.pt")
|
| 246 |
+
torch.save(value_states, "value_states.pt")
|
| 247 |
+
|
| 248 |
+
from huggingface_hub import HfApi
|
| 249 |
+
|
| 250 |
+
api = HfApi()
|
| 251 |
+
api.upload_file(
|
| 252 |
+
path_or_fileobj="query_states.pt",
|
| 253 |
+
path_in_repo="query_states.pt",
|
| 254 |
+
repo_id="nielsr/test-cogvlm",
|
| 255 |
+
repo_type="dataset",
|
| 256 |
+
)
|
| 257 |
+
api = HfApi()
|
| 258 |
+
api.upload_file(
|
| 259 |
+
path_or_fileobj="key_states.pt",
|
| 260 |
+
path_in_repo="key_states.pt",
|
| 261 |
+
repo_id="nielsr/test-cogvlm",
|
| 262 |
+
repo_type="dataset",
|
| 263 |
+
)
|
| 264 |
+
api = HfApi()
|
| 265 |
+
api.upload_file(
|
| 266 |
+
path_or_fileobj="value_states.pt",
|
| 267 |
+
path_in_repo="value_states.pt",
|
| 268 |
+
repo_id="nielsr/test-cogvlm",
|
| 269 |
+
repo_type="dataset",
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
kv_seq_len = key_states.shape[-2]
|
| 273 |
if past_key_value is not None:
|
| 274 |
kv_seq_len += past_key_value[0].shape[-2]
|
|
|
|
| 337 |
past_key_value=past_key_value,
|
| 338 |
output_attentions=output_attentions,
|
| 339 |
use_cache=use_cache,
|
| 340 |
+
print_values=print_values,
|
| 341 |
)
|
| 342 |
|
| 343 |
if print_values:
|