sunjuice commited on
Commit
43189e0
·
1 Parent(s): c969eb1

moving past keyvalue dynamically

Browse files
Files changed (1) hide show
  1. modeling_molmo2.py +15 -24
modeling_molmo2.py CHANGED
@@ -1054,31 +1054,22 @@ class Molmo2TextModel(Molmo2PreTrainedModel):
1054
  position_embeddings_i = position_embeddings
1055
 
1056
  block_device = next(decoder_block.parameters()).device
1057
- if hidden_states.device != block_device:
1058
- hidden_states = hidden_states.to(block_device)
1059
 
1060
- if causal_mask_mapping is not None and causal_mask_mapping.device != block_device:
1061
- causal_mask_mapping = causal_mask_mapping.to(block_device)
1062
-
1063
- if position_ids is not None and position_ids.device != block_device:
1064
- position_ids = position_ids.to(block_device)
1065
-
1066
- if position_embeddings_i is not None and position_embeddings_i.device != block_device:
1067
- position_embeddings_i = position_embeddings_i.to(block_device)
1068
-
1069
- if cache_position is not None and cache_position.device != block_device:
1070
- cache_position = cache_position.to(block_device)
1071
-
1072
- def move_past_key_values(past_key_values, device):
1073
- if past_key_values is None:
1074
- return None
1075
- return tuple(
1076
- tuple(p.to(device) for p in layer)
1077
- for layer in past_key_values
1078
- )
1079
-
1080
- past_key_values = move_past_key_values(past_key_values, block_device)
1081
- position_embeddings_i = position_embeddings_i.to(block_device)
1082
 
1083
  layer_outputs = decoder_block(
1084
  hidden_states,
 
1054
  position_embeddings_i = position_embeddings
1055
 
1056
  block_device = next(decoder_block.parameters()).device
 
 
1057
 
1058
+ def move_if_needed(x, device):
1059
+ if x is None:
1060
+ return x
1061
+ if isinstance(x, tuple):
1062
+ return tuple(move_if_needed(xx, device) for xx in x)
1063
+ if hasattr(x, "device") and x.device != device:
1064
+ return x.to(device)
1065
+ return x
1066
+
1067
+ hidden_states = move_if_needed(hidden_states, block_device)
1068
+ causal_mask_mapping = move_if_needed(causal_mask_mapping, block_device)
1069
+ position_ids = move_if_needed(position_ids, block_device)
1070
+ position_embeddings_i = move_if_needed(position_embeddings_i, block_device)
1071
+ cache_position = move_if_needed(cache_position, block_device)
1072
+ past_key_values = move_if_needed(past_key_values, block_device)
 
 
 
 
 
 
 
1073
 
1074
  layer_outputs = decoder_block(
1075
  hidden_states,