Moving past_key_values to each decoder block's device dynamically
Browse files- modeling_molmo2.py +15 -24
modeling_molmo2.py
CHANGED
|
@@ -1054,31 +1054,22 @@ class Molmo2TextModel(Molmo2PreTrainedModel):
|
|
| 1054 |
position_embeddings_i = position_embeddings
|
| 1055 |
|
| 1056 |
block_device = next(decoder_block.parameters()).device
|
| 1057 |
-
if hidden_states.device != block_device:
|
| 1058 |
-
hidden_states = hidden_states.to(block_device)
|
| 1059 |
|
| 1060 |
-
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
-
|
| 1065 |
-
|
| 1066 |
-
|
| 1067 |
-
|
| 1068 |
-
|
| 1069 |
-
|
| 1070 |
-
|
| 1071 |
-
|
| 1072 |
-
|
| 1073 |
-
|
| 1074 |
-
|
| 1075 |
-
return tuple(
|
| 1076 |
-
tuple(p.to(device) for p in layer)
|
| 1077 |
-
for layer in past_key_values
|
| 1078 |
-
)
|
| 1079 |
-
|
| 1080 |
-
past_key_values = move_past_key_values(past_key_values, block_device)
|
| 1081 |
-
position_embeddings_i = position_embeddings_i.to(block_device)
|
| 1082 |
|
| 1083 |
layer_outputs = decoder_block(
|
| 1084 |
hidden_states,
|
|
|
|
| 1054 |
position_embeddings_i = position_embeddings
|
| 1055 |
|
| 1056 |
block_device = next(decoder_block.parameters()).device
|
|
|
|
|
|
|
| 1057 |
|
| 1058 |
+
def move_if_needed(x, device):
    """Recursively move *x* onto *device*, leaving non-movable values alone.

    Handles three shapes of input:
      * ``None`` — returned unchanged;
      * a ``tuple`` — each element is processed recursively and a new
        tuple is returned (e.g. legacy past_key_values nesting);
      * anything with a ``device`` attribute (e.g. a tensor) — moved via
        ``.to(device)`` only when it is not already on *device*.
    Plain Python values without a ``device`` attribute pass through as-is.
    """
    if x is None:
        return None
    if isinstance(x, tuple):
        moved = []
        for item in x:
            moved.append(move_if_needed(item, device))
        return tuple(moved)
    # Only relocate objects that actually live on a device and are elsewhere.
    needs_move = hasattr(x, "device") and x.device != device
    return x.to(device) if needs_move else x
|
| 1066 |
+
|
| 1067 |
+
hidden_states = move_if_needed(hidden_states, block_device)
|
| 1068 |
+
causal_mask_mapping = move_if_needed(causal_mask_mapping, block_device)
|
| 1069 |
+
position_ids = move_if_needed(position_ids, block_device)
|
| 1070 |
+
position_embeddings_i = move_if_needed(position_embeddings_i, block_device)
|
| 1071 |
+
cache_position = move_if_needed(cache_position, block_device)
|
| 1072 |
+
past_key_values = move_if_needed(past_key_values, block_device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1073 |
|
| 1074 |
layer_outputs = decoder_block(
|
| 1075 |
hidden_states,
|