| |
| |
| |
| |
| @@ -2232,7 +2232,7 @@ def _prefetch_next_layer(self, layer_idx: int) -> None: |
| |
| def _prefetch_layer_in_context(self, layer_idx: int) -> None: |
| """Performs the actual copy of the layer to device cache.""" |
| - if len(self.key_cache) >= layer_idx: |
| + if len(self.key_cache) > layer_idx: |
| self.device_key_cache[self.active_device_layer].copy_(self.key_cache[layer_idx], non_blocking=True) |
| self.device_value_cache[self.active_device_layer].copy_(self.value_cache[layer_idx], non_blocking=True) |
| # The layer was not yet initialized |
| |
| |
| |
| |
| @@ -11,7 +11,6 @@ |
| # specific language governing permissions and limitations under the License. |
| |
| import logging |
| -from contextlib import contextmanager |
| from typing import Callable, Optional |
| |
| import torch |
| @@ -110,14 +109,13 @@ def export( |
| example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long) |
| example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long) |
| |
| - with patch_mask_interface(): |
| - exported_program = torch.export.export( |
| - self.model, |
| - args=(example_input_ids, example_cache_position), |
| - kwargs={}, |
| - dynamic_shapes=dynamic_shapes, |
| - strict=strict if strict is not None else True, |
| - ) |
| + exported_program = torch.export.export( |
| + self.model, |
| + args=(example_input_ids, example_cache_position), |
| + kwargs={}, |
| + dynamic_shapes=dynamic_shapes, |
| + strict=strict if strict is not None else True, |
| + ) |
| return exported_program |
| |
| @staticmethod |
| @@ -456,24 +454,6 @@ def forward( |
| return outputs.logits |
| |
| |
| -@contextmanager |
| -def patch_mask_interface(): |
| - """ |
| - Context manager to locally use a simple dict instead of `AttentionMaskInterface`, as otherwise export will fail |
| - with `strict=True` due to dynamo skip rules, i.e. `torch._dynamo.exc.Unsupported: 'inline in skipfiles: |
| - Mapping.__contains__ | __contains__, skipped according trace_rules.lookup SKIP_DIRS'`. |
| - Note that this seem to be an issue only for python<3.11. |
| - """ |
| - import transformers |
| - |
| - original = transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS |
| - transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = ALL_MASK_ATTENTION_FUNCTIONS._global_mapping |
| - try: |
| - yield |
| - finally: |
| - transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = original |
| - |
| - |
| def convert_and_export_with_cache( |
| model: PreTrainedModel, |
| example_input_ids: Optional[torch.Tensor] = None, |
| @@ -515,14 +495,13 @@ def convert_and_export_with_cache( |
| ) |
| |
| if is_torch_greater_or_equal("2.6.0"): |
| - with patch_mask_interface(): |
| - exported_program = torch.export.export( |
| - TorchExportableModuleWithStaticCache(model), |
| - args=(example_input_ids, example_cache_position), |
| - kwargs={}, |
| - dynamic_shapes=dynamic_shapes, |
| - strict=strict if strict is not None else True, |
| - ) |
| + exported_program = torch.export.export( |
| + TorchExportableModuleWithStaticCache(model), |
| + args=(example_input_ids, example_cache_position), |
| + kwargs={}, |
| + dynamic_shapes=dynamic_shapes, |
| + strict=strict if strict is not None else True, |
| + ) |
| else: |
| if dynamic_shapes is not None: |
| logging.warning( |
| @@ -534,14 +513,13 @@ def convert_and_export_with_cache( |
| # |
| # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal |
| # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release. |
| - with patch_mask_interface(): |
| - exported_program = torch.export._trace._export( |
| - TorchExportableModuleWithStaticCache(model), |
| - args=(example_input_ids,), |
| - kwargs={"cache_position": example_cache_position}, |
| - pre_dispatch=False, |
| - strict=True, |
| - ) |
| + exported_program = torch.export._trace._export( |
| + TorchExportableModuleWithStaticCache(model), |
| + args=(example_input_ids,), |
| + kwargs={"cache_position": example_cache_position}, |
| + pre_dispatch=False, |
| + strict=True, |
| + ) |
| return exported_program |
| |
| |
| @@ -634,10 +612,9 @@ def _export_encoder(self, encoder_input_ids): |
| |
| # Export the encoder |
| with torch.no_grad(): |
| - with patch_mask_interface(): |
| - exported_encoder = torch.export.export( |
| - wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True |
| - ) |
| + exported_encoder = torch.export.export( |
| + wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True |
| + ) |
| |
| return exported_encoder |
| |
| @@ -657,17 +634,16 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi |
| |
| # Export the decoder |
| with torch.no_grad(): |
| - with patch_mask_interface(): |
| - exported_decoder = torch.export.export( |
| - wrapped_decoder, |
| - (decoder_input_ids, encoder_hidden_states, cache_position), |
| - dynamic_shapes={ |
| - "decoder_input_ids": None, |
| - "encoder_hidden_states": {1: encoder_seq_len_dim}, |
| - "cache_position": None, |
| - }, |
| - strict=True, |
| - ) |
| + exported_decoder = torch.export.export( |
| + wrapped_decoder, |
| + (decoder_input_ids, encoder_hidden_states, cache_position), |
| + dynamic_shapes={ |
| + "decoder_input_ids": None, |
| + "encoder_hidden_states": {1: encoder_seq_len_dim}, |
| + "cache_position": None, |
| + }, |
| + strict=True, |
| + ) |
| |
| return exported_decoder |
| |
| |
| |
| |
| |
| @@ -623,7 +623,11 @@ def _preprocess_mask_arguments( |
| return True, attention_mask, None, None |
| |
| # For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask! |
| - if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS: |
| + # Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise |
| + # full graph dynamo tracing (i.e. torch.export or compile with `fullgraph=True`) will fail on Python<3.11 |
| + # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped |
| + # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11 |
| + if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping: |
| return True, None, None, None |
| |
| # Move the mask to correct device, and potentially switch dtype for efficiency |
| |
| |
| |
| |
| @@ -232,8 +232,8 @@ def test_batched_small_model_logits(self): |
| |
| EXPECTED_LOGITS = torch.Tensor( |
| [ |
| - [[0.0000, 0.1866, -0.1997], [0.0000, -0.0736, 0.1785], [0.0000, -0.1965, -0.0569]], |
| - [[0.0000, -0.0302, 0.1488], [0.0000, -0.0402, 0.1351], [0.0000, -0.0341, 0.1116]], |
| + [[0.0000, 0.0285, 0.0322], [0.0000, 0.0011, 0.1105], [0.0000, -0.0018, -0.1019]], |
| + [[0.0000, 0.1080, 0.0454], [0.0000, -0.1808, -0.1553], [0.0000, 0.0452, 0.0369]], |
| ] |
| ).to(device=torch_device, dtype=torch.float16) |
| |
| @@ -251,4 +251,4 @@ def test_batched_small_model_logits(self): |
| output = model(**inputs) |
| |
| logits = output.logits |
| - torch.testing.assert_close(EXPECTED_LOGITS, logits[:, :3, :3], rtol=1e-3, atol=1e-3) |
| + torch.testing.assert_close(EXPECTED_LOGITS, logits[:, -3:, :3], rtol=1e-3, atol=1e-3) |
| |
| |
| |
| |
| @@ -150,7 +150,6 @@ class CsmForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, u |
| test_headmasking = False |
| test_resize_embeddings = False |
| test_resize_embeddings_untied = False |
| - test_torch_exportable = True |
| |
| def setUp(self): |
| self.model_tester = CsmModelTester(self) |
| |
| |
| |
| |
| @@ -402,24 +402,12 @@ def test_small_model_logits_batched(self): |
| # |
| # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s, |
| # considering differences in hardware processing and potential deviations in generated text. |
| - EXPECTED_LOGITS_LEFT = { |
| - 7: torch.Tensor( |
| - [[0.1904, 0.0500, 0.7187], [0.1933, 0.0515, 0.7187], [0.2001, 0.0559, 0.7148]], |
| - ).to(torch_device), |
| - 8: torch.Tensor([[0.1914, 0.0508, 0.7188], [0.1953, 0.0510, 0.7227], [0.1973, 0.0562, 0.7148]]).to( |
| - torch_device |
| - ), |
| - 9: torch.Tensor([[0.1904, 0.0513, 0.7227], [0.1943, 0.0518, 0.7227], [0.1982, 0.0557, 0.7148]]).to( |
| - torch_device |
| - ), |
| - } |
| - |
| EXPECTED_LOGITS_LEFT_UNPADDED = { |
| 7: torch.Tensor( |
| [[0.2236, 0.5195, -0.3828], [0.8203, -0.2275, 0.6054], [0.2656, -0.7070, 0.2460]], |
| ).to(torch_device), |
| - 8: torch.Tensor([[0.2217, 0.5195, -0.3828], [0.8203, -0.2295, 0.6055], [0.2676, -0.7109, 0.2461]]).to( |
| - torch_device |
| + 8: torch.Tensor([[0.2207, 0.5234, -0.3828], [0.8203, -0.2285, 0.6055], [0.2656, -0.7109, 0.2451]]).to( |
| + torch_device, |
| ), |
| 9: torch.Tensor([[0.2236, 0.5195, -0.3828], [0.8203, -0.2285, 0.6055], [0.2637, -0.7109, 0.2451]]).to( |
| torch_device |
| @@ -430,8 +418,8 @@ def test_small_model_logits_batched(self): |
| 7: torch.Tensor([[0.2167, 0.1269, -0.1640], [-0.3496, 0.2988, -1.0312], [0.0688, 0.7929, 0.8007]]).to( |
| torch_device |
| ), |
| - 8: torch.Tensor([[0.2178, 0.1260, -0.1621], [-0.3496, 0.2988, -1.0312], [0.0693, 0.7930, 0.8008]]).to( |
| - torch_device |
| + 8: torch.Tensor([[0.2178, 0.1270, -0.1621], [-0.3496, 0.3008, -1.0312], [0.0693, 0.7930, 0.7969]]).to( |
| + torch_device, |
| ), |
| 9: torch.Tensor([[0.2197, 0.1250, -0.1611], [-0.3516, 0.3008, -1.0312], [0.0684, 0.7930, 0.8008]]).to( |
| torch_device |
| @@ -442,9 +430,6 @@ def test_small_model_logits_batched(self): |
| logits = model(dummy_input, attention_mask=attention_mask).logits |
| logits = logits.float() |
| |
| - torch.testing.assert_close( |
| - logits[0, :3, :3], EXPECTED_LOGITS_LEFT[self.cuda_compute_capability_major_version], atol=1e-3, rtol=1e-3 |
| - ) |
| torch.testing.assert_close( |
| logits[0, -3:, -3:], |
| EXPECTED_LOGITS_LEFT_UNPADDED[self.cuda_compute_capability_major_version], |
| |
| |
| |
| |
| @@ -4461,6 +4461,7 @@ def test_torch_compile_for_training(self): |
| del loss |
| |
| model = torch.compile(model, fullgraph=True, mode="reduce-overhead") |
| + |
| # forward compilation |
| set_seed(42) |
| loss = model(**inputs).loss |
|
|