# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
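
"""Tests for the ExecuTorch integration in `transformers.integrations.executorch`.

The tests compare logits from the export-friendly wrapper modules against eager,
cache-free execution of the underlying model, and check input validation in
`TorchExportableModuleForDecoderOnlyLM.export`.
"""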

import unittest

import torch

from transformers import AutoModelForCausalLM, set_seed
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import (
    TorchExportableModuleForDecoderOnlyLM,
    TorchExportableModuleWithHybridCache,
    TorchExportableModuleWithStaticCache,
)
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
from transformers.testing_utils import require_torch


@require_torch
class ExecutorchTest(unittest.TestCase):
    def setUp(self):
        if not is_torch_greater_or_equal_than_2_3:
            self.skipTest("torch >= 2.3 is required")

        set_seed(0)
        self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
        self.model.eval()

        # Use a static KV cache so the model can be wrapped for export; the cache
        # is preallocated for batch size 1 and up to 32 positions on CPU.
        self.model.generation_config = GenerationConfig(
            use_cache=True,
            cache_implementation="static",
            cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
        )

        # Small fixed inputs: a 3-token prompt, matching random embeddings, and
        # cache positions 0..2 for the first forward pass.
        self.input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
        self.inputs_embeds = torch.randn(1, 3, self.model.config.hidden_size)
        self.cache_position = torch.arange(3, dtype=torch.long)

    def test_static_cache_module_forward(self):
        """Test TorchExportableModuleWithStaticCache forward with both input types"""
        generation_config = GenerationConfig(
            use_cache=True,
            cache_implementation="static",
            cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
        )

        # Set generation config on model
        self.model.generation_config = generation_config
        module = TorchExportableModuleWithStaticCache(self.model)
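        # On the first forward pass, the wrapper's static-cache logits should match
        # an eager, cache-free forward over the same positions.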

        # Test with input_ids
        eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
        wrapped_output_ids = module.forward(input_ids=self.input_ids, cache_position=self.cache_position)
        torch.testing.assert_close(eager_output_ids, wrapped_output_ids, atol=1e-4, rtol=1e-4)

        # Test with inputs_embeds
        eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
        wrapped_output_embeds = module.forward(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
        torch.testing.assert_close(eager_output_embeds, wrapped_output_embeds, atol=1e-4, rtol=1e-4)

    def test_hybrid_cache_module_forward(self):
        """Test TorchExportableModuleWithHybridCache forward with both input types"""
        config = self.model.config
        config.sliding_window = 16
        config.layer_types = ["full_attention"] * config.num_hidden_layers

        generation_config = GenerationConfig(
            use_cache=True,
            cache_implementation="hybrid",
            cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
        )

        # Set generation config on model
        self.model.generation_config = generation_config
        module = TorchExportableModuleWithHybridCache(self.model)

        # Test with input_ids
        eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
        wrapped_output_ids = module.forward(input_ids=self.input_ids, cache_position=self.cache_position)
        torch.testing.assert_close(eager_output_ids, wrapped_output_ids, atol=1e-4, rtol=1e-4)

        # Test with inputs_embeds
        eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
        wrapped_output_embeds = module.forward(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
        torch.testing.assert_close(eager_output_embeds, wrapped_output_embeds, atol=1e-4, rtol=1e-4)

    def test_decoder_only_lm_export_validation(self):
        """Test TorchExportableModuleForDecoderOnlyLM export validation"""
        module = TorchExportableModuleForDecoderOnlyLM(self.model)

        # Should fail with both input_ids and inputs_embeds
        with self.assertRaises(ValueError):
            module.export(input_ids=self.input_ids, inputs_embeds=self.inputs_embeds)

        # Should fail with neither
        with self.assertRaises(ValueError):
            module.export()

    def test_decoder_only_lm_export(self):
        """Test TorchExportableModuleForDecoderOnlyLM export with both input types"""
        module = TorchExportableModuleForDecoderOnlyLM(self.model)
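        # export() returns a torch.export.ExportedProgram; calling .module() on it
        # yields a runnable GraphModule whose outputs are compared against eager.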

        # Test export with input_ids
        exported_program_ids = module.export(input_ids=self.input_ids, cache_position=self.cache_position)
        eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
        exported_output_ids = exported_program_ids.module()(
            input_ids=self.input_ids, cache_position=self.cache_position
        )
        torch.testing.assert_close(eager_output_ids, exported_output_ids, atol=1e-4, rtol=1e-4)

        # Test export with inputs_embeds
        exported_program_embeds = module.export(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
        eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
        exported_output_embeds = exported_program_embeds.module()(
            inputs_embeds=self.inputs_embeds, cache_position=self.cache_position
        )
        torch.testing.assert_close(eager_output_embeds, exported_output_embeds, atol=1e-4, rtol=1e-4)