| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """Testing suite for the PyTorch Bamba model.""" |
| |
|
| | import inspect |
| | import tempfile |
| | import unittest |
| |
|
| | import pytest |
| | from pytest import mark |
| |
|
| | from transformers import ( |
| | AutoTokenizer, |
| | BambaConfig, |
| | DataCollatorWithFlattening, |
| | is_torch_available, |
| | ) |
| | from transformers.testing_utils import ( |
| | DeviceProperties, |
| | Expectations, |
| | get_device_properties, |
| | require_deterministic_for_xpu, |
| | require_flash_attn, |
| | require_torch, |
| | require_torch_accelerator, |
| | slow, |
| | torch_device, |
| | ) |
| |
|
| | from ...generation.test_utils import GenerationTesterMixin |
| | from ...test_configuration_common import ConfigTester |
| | from ...test_modeling_common import ModelTesterMixin, ids_tensor |
| | from ...test_pipeline_mixin import PipelineTesterMixin |
| |
|
| |
|
| | if is_torch_available(): |
| | import torch |
| |
|
| | from transformers import ( |
| | BambaForCausalLM, |
| | BambaModel, |
| | ) |
| | from transformers.models.bamba.modeling_bamba import HybridMambaAttentionDynamicCache |
| |
|
| |
|
class BambaModelTester:
    """Helper that builds a small `BambaConfig` plus random inputs and runs basic forward checks.

    All assertions are reported through ``parent``, the test case that owns this tester.
    """

    config_class = BambaConfig
    if is_torch_available():
        model_class = BambaModel
        for_causal_lm_class = BambaForCausalLM

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        intermediate_size=64,
        hidden_act="silu",
        attention_dropout=0.0,
        attn_layer_indices=None,
        attn_rotary_emb=8,
        max_position_embeddings=512,
        type_vocab_size=16,
        initializer_range=0.02,
        num_labels=3,
        pad_token_id=0,
        mamba_n_groups=1,
        mamba_n_heads=16,
        mamba_d_state=16,
        mamba_d_conv=4,
        mamba_expand=2,
        mamba_chunk_size=16,
        scope=None,
    ):
        """Store the test hyper-parameters verbatim; nothing is validated or built here."""
        self.parent = parent

        # Input / batch settings.
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.num_labels = num_labels
        self.scope = scope

        # Generic model settings.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.attention_dropout = attention_dropout
        self.attn_layer_indices = attn_layer_indices
        self.attn_rotary_emb = attn_rotary_emb
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.pad_token_id = pad_token_id

        # Mamba-specific settings.
        self.mamba_n_groups = mamba_n_groups
        self.mamba_n_heads = mamba_n_heads
        self.mamba_d_state = mamba_d_state
        self.mamba_d_conv = mamba_d_conv
        self.mamba_expand = mamba_expand
        self.mamba_chunk_size = mamba_chunk_size

    def prepare_config_and_inputs(self):
        """Return ``(config, input_ids, input_mask, token_labels)``.

        ``input_mask`` is a lower-triangular matrix of ones (or ``None`` when
        ``use_input_mask`` is false); ``token_labels`` is ``None`` when ``use_labels`` is false.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))

        token_labels = None
        if self.use_labels:
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)

        # Layer count / attention indices must be settled before the config is built.
        self._update_layer_configs()
        config = self.get_config()

        return config, input_ids, input_mask, token_labels

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` with ``input_ids`` and ``attention_mask`` entries."""
        config, input_ids, input_mask, _token_labels = self.prepare_config_and_inputs()
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict

    def _update_layer_configs(self):
        """Configures hidden layers and attn layer indices if they are not set.

        Raises:
            ValueError: if ``attn_layer_indices`` is unset and ``num_hidden_layers`` has no
                divisor in ``[2, num_hidden_layers)``.
        """
        # Use at least four layers.
        if self.num_hidden_layers < 4:
            self.num_hidden_layers = 4

        if self.attn_layer_indices is None:
            divisors = [x for x in range(2, self.num_hidden_layers) if self.num_hidden_layers % x == 0]
            if not divisors:
                raise ValueError("num_hidden_layers is prime, cannot automatically set attn_layer_indices.")
            # Step by the largest proper divisor, shifting every index by one
            # (e.g. 4 layers -> stride 2 -> indices [1, 3]).
            stride = divisors[-1]
            self.attn_layer_indices = [x + 1 for x in range(0, self.num_hidden_layers, stride)]

    def get_config(self, **kwargs):
        """Instantiate ``config_class`` from the stored settings; ``kwargs`` are forwarded as-is."""
        return self.config_class(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            attention_dropout=self.attention_dropout,
            attn_layer_indices=self.attn_layer_indices,
            attn_rotary_emb=self.attn_rotary_emb,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            pad_token_id=self.pad_token_id,
            mamba_n_groups=self.mamba_n_groups,
            mamba_n_heads=self.mamba_n_heads,
            mamba_d_state=self.mamba_d_state,
            mamba_d_conv=self.mamba_d_conv,
            mamba_expand=self.mamba_expand,
            mamba_chunk_size=self.mamba_chunk_size,
            **kwargs,
        )

    def create_and_check_model(
        self,
        config,
        input_ids,
        input_mask,
        token_labels,
    ):
        """Run the base model with and without a mask and check the hidden-state shape."""
        model = self.model_class(config=config)
        model.to(torch_device)
        model.eval()

        # Smoke call: only checks that the masked forward pass runs.
        model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        input_mask,
        token_labels,
    ):
        """Run the causal-LM head with every mask/labels combination and check the logits shape."""
        model = self.for_causal_lm_class(config=config)
        model.to(torch_device)
        model.eval()

        # Smoke calls: only the last, argument-free call is inspected.
        model(input_ids, attention_mask=input_mask, labels=token_labels)
        model(input_ids, attention_mask=input_mask)
        model(input_ids, labels=token_labels)
        result = model(input_ids)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        input_mask,
        token_labels,
    ):
        """Compare a cached forward pass on new tokens against an uncached pass on the full sequence."""
        model = self.for_causal_lm_class(config=config)
        model.to(torch_device)
        model.eval()

        # First forward pass fills an explicitly constructed hybrid cache.
        past_key_values = HybridMambaAttentionDynamicCache(
            config, input_ids.shape[0], model.dtype, device=model.device
        )
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # Hypothetical continuation: a few extra tokens plus a mask for them
        # (ids_tensor with vocab_size=2 — presumably random 0/1 values).
        num_new_tokens = 3
        next_tokens = ids_tensor((self.batch_size, num_new_tokens), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, num_new_tokens), vocab_size=2)

        # Full sequence and full mask for the uncached reference run.
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        # NOTE(review): index 0 of `hidden_states` is compared, not the last entry —
        # confirm this is the intended layer to check.
        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            output_hidden_states=True,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
            cache_position=torch.arange(
                input_ids.shape[1], input_ids.shape[1] + next_tokens.shape[1], device=model.device
            ),
        )["hidden_states"][0]

        # Compare one randomly chosen feature column over the new positions.
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -num_new_tokens:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.parent.assertEqual(output_from_past_slice.shape[1], next_tokens.shape[1])

        # Cached and uncached slices must agree within tolerance.
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
| |
|
| |
|
@require_torch
class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Runs the shared model/generation/pipeline mixin tests for Bamba, with Bamba-specific overrides."""

    model_tester_class = BambaModelTester
    all_model_classes = (BambaModel, BambaForCausalLM) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": BambaModel,
            "text-generation": BambaForCausalLM,
        }
        if is_torch_available()
        else {}
    )

    # Not read anywhere in this file — presumably consumed by inherited mixin tests;
    # confirm before changing.
    model_split_percents = [0.5, 0.7, 0.8]

    def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config):
        """Check the cache type, its layer count and the per-layer state shapes."""
        self.assertIsInstance(past_key_values, HybridMambaAttentionDynamicCache)

        # Expected shape of the key/value tensors on attention layers.
        num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        attention_shape = (batch_size, num_heads, seq_length, head_dim)

        # Expected shapes of the convolution / SSM states on mamba layers.
        conv_shape = (
            batch_size,
            config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * config.mamba_d_state,
            config.mamba_d_conv,
        )
        ssm_shape = (batch_size, config.mamba_n_heads, config.mamba_d_head, config.mamba_d_state)

        # Must be assertEqual: assertTrue(a, b) treats `b` as the failure message
        # and passes for any truthy `a`, so it would never compare the two values.
        self.assertEqual(config.num_hidden_layers, len(past_key_values))

        for idx in range(len(past_key_values)):
            if config.layers_block_type[idx] == "mamba":
                self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape)
                self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape)
            else:
                self.assertEqual(past_key_values.key_cache[idx].shape, attention_shape)
                self.assertEqual(past_key_values.value_cache[idx].shape, attention_shape)

    def _check_caches_are_equal(
        self, cache1: HybridMambaAttentionDynamicCache, cache2: HybridMambaAttentionDynamicCache
    ):
        """Assert that two hybrid caches hold matching tensors in every layer.

        Raises:
            ValueError: if either argument is not a `HybridMambaAttentionDynamicCache`,
                or if the two caches differ in length.
        """
        if not isinstance(cache1, HybridMambaAttentionDynamicCache) or not isinstance(
            cache2, HybridMambaAttentionDynamicCache
        ):
            raise ValueError("The wrong cache is being used!")

        if not len(cache1) == len(cache2):
            raise ValueError("Both caches do not have the same number of layers.")

        num_layers = len(cache1)
        for idx in range(num_layers):
            # All four state lists are compared for every layer, whatever its block type.
            torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx])
            torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx])
            torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx])
            torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx])

    def setUp(self):
        """Create a fresh model tester and config tester for each test."""
        self.model_tester = self.model_tester_class(self)
        self.config_tester = ConfigTester(self, config_class=self.model_tester.config_class, hidden_size=64)

    def test_config(self):
        """Run the shared configuration checks."""
        self.config_tester.run_common_tests()

    def test_model(self):
        """Forward pass through the base model."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_for_causal_lm(self):
        """Forward pass through the causal-LM head."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)

    def test_decoder_model_past_with_large_inputs(self):
        """Cached vs. uncached forward pass on a multi-token continuation."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_attention_outputs(self):
        r"""
        Overriding the test_attention_outputs test as the Bamba model outputs attention only for its attention layers
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

        # One attention tensor per attention layer. The previous formula
        # (num_hidden_layers - len(attn_layer_indices)) counted the *other* layers and only
        # matched because the default config has as many attention as non-attention layers.
        expected_num_attentions = len(self.model_tester.attn_layer_indices)

        for model_class in self.all_model_classes:
            # 1) output_attentions requested through the call arguments.
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class._from_config(config, attn_implementation="eager")
            config = model.config
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions
            self.assertEqual(len(attentions), expected_num_attentions)

            # 2) output_attentions requested through the config instead.
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions
            self.assertEqual(len(attentions), expected_num_attentions)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
            )
            out_len = len(outputs)

            # 3) Also requesting hidden states must add exactly one output entry.
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.attentions

            self.assertEqual(len(self_attentions), expected_num_attentions)
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
            )

    def test_batching_equivalence(self):
        """Run the inherited batching-equivalence test with the tester's input mask disabled."""
        orig = self.model_tester.use_input_mask
        self.model_tester.use_input_mask = False
        try:
            super().test_batching_equivalence()
        finally:
            # Restore the flag even when the inherited test fails or is skipped.
            self.model_tester.use_input_mask = orig

    @pytest.mark.generate
    def test_left_padding_compatibility(self):
        """Run the inherited left-padding test with ``attention_mask=None`` for the unpadded inputs."""
        unpadded_custom_inputs = {"attention_mask": None}
        super().test_left_padding_compatibility(unpadded_custom_inputs=unpadded_custom_inputs)

    @unittest.skip(
        "Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
    )
    def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
        pass

    @unittest.skip(
        "Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
    )
    def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
        pass

    @require_flash_attn
    @require_torch_accelerator
    @mark.flash_attn_test
    @slow
    @unittest.skip(
        "NotImplementedError: seq_idx support requires fast path support. Please install mamba_ssm and causal_conv1d"
    )
    def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_idx_and_fa_kwargs(self):
        """Compare padded inputs against a flattened, padding-free batch under Flash Attention 2.

        Currently skipped unconditionally (see the ``unittest.skip`` reason above).
        """
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        max_new_tokens = 30

        for model_class in self.all_generative_model_classes:
            if not model_class._supports_flash_attn:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            # The comparison is only meaningful when the mask contains some padding.
            if "attention_mask" not in inputs_dict or 0 not in inputs_dict["attention_mask"]:
                self.skipTest("Model dummy inputs should contain padding in their attention mask")

            dummy_input = inputs_dict[model_class.main_input_name]
            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
                dummy_input = dummy_input.to(torch.float16)

            # Leave enough positions for the input plus `max_new_tokens`.
            if hasattr(config, "max_position_embeddings"):
                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

            model = model_class(config)
            if "position_ids" not in inspect.signature(model.forward).parameters:
                self.skipTest("Model does not support position_ids")

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                # Flip the mask if the last column holds padding, so padding ends up on the left.
                if 0 in inputs_dict["attention_mask"][:, -1]:
                    inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
                dummy_attention_mask = inputs_dict["attention_mask"]
                inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id

                # Add labels so the losses of both paths can be compared.
                labels = inputs_dict["input_ids"].clone()
                # Ignore padded positions in the loss.
                labels[~dummy_attention_mask.bool()] = -100
                # Also ignore the first non-masked label of every row.
                first_nonneg_idx = (labels >= 0).int().argmax(dim=1)
                labels[torch.arange(labels.size(0), device=labels.device), first_nonneg_idx] = -100
                inputs_dict["labels"] = labels

                model = (
                    model_class.from_pretrained(
                        tmpdirname,
                        dtype=torch.float16,
                        attn_implementation="flash_attention_2",
                    )
                    .to(torch_device)
                    .eval()
                )

                # One feature per row, keeping only the non-padded token ids.
                features = [
                    {"input_ids": i[a.bool()].tolist()}
                    for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
                ]

                # The collator is asked to also return seq_idx and flash-attention kwargs.
                data_collator = DataCollatorWithFlattening(
                    return_tensors="pt", return_seq_idx=True, return_flash_attn_kwargs=True
                )
                batch = data_collator(features)
                batch_accelerator = {k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()}

                res_padded = model(**inputs_dict)
                res_padfree = model(**batch_accelerator)

                logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
                logits_padfree = res_padfree.logits[0]

                # Predicted tokens must match exactly; raw logits within fp16 epsilon.
                torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
                tol = torch.finfo(torch.float16).eps
                torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)

                loss_padded = res_padded.loss
                loss_padfree = res_padfree.loss
                torch.testing.assert_close(loss_padded, loss_padfree)
| |
|
| |
|
@slow
@require_torch
@require_torch_accelerator
class BambaModelIntegrationTest(unittest.TestCase):
    """Slow generation tests against the ``ibm-fms/Bamba-9B`` checkpoint."""

    model = None
    tokenizer = None
    # Placeholder until `setUpClass` stores the result of `get_device_properties()`.
    device_properties: DeviceProperties = (None, None, None)

    @classmethod
    def setUpClass(cls):
        """Load the checkpoint and tokenizer once for all tests in the class."""
        model_id = "ibm-fms/Bamba-9B"
        cls.model = BambaForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16)
        cls.tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Pad with the model's pad token, on the left.
        cls.tokenizer.pad_token_id = cls.model.config.pad_token_id
        cls.tokenizer.padding_side = "left"

        cls.device_properties = get_device_properties()

    def test_simple_generate(self):
        """Greedy generation on a single prompt, plus a logits check on ("cuda", 8)."""
        prompt = "Hey how are you doing on this lovely evening?"
        expectations = Expectations(
            {
                ("cuda", 8): "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all having a good time.",
                ("rocm", 9): "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
                ("xpu", 3): "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all doing well. I am",
            }
        )

        self.model.to(torch_device)

        encoded = self.tokenizer(prompt, return_tensors="pt")
        input_ids = encoded["input_ids"].to(torch_device)
        generated = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
        decoded = self.tokenizer.decode(generated[0, :])
        self.assertEqual(decoded, expectations.get_expectation())

        # Reference logits are only available for the ("cuda", 8) device key.
        device_type, device_major = self.device_properties[0], self.device_properties[1]
        if device_type == "cuda" and device_major == 8:
            with torch.no_grad():
                logits = self.model(input_ids=input_ids, logits_to_keep=40).logits

            expected_logits = torch.tensor(
                [
                    149.0, 142.0, 146.0, 142.0, 143.0, 144.0, 142.0, 145.0,
                    142.0, 146.0, 144.0, 146.0, 147.0, 147.0, 148.0, 145.0,
                    147.0, 145.0, 145.0, 145.0, 145.0, 144.0, 144.0, 144.0,
                    144.0, 145.0, 147.0, 146.0, 144.0, 144.0, 148.0, 147.0,
                    148.0, 147.0, 147.0, 147.0, 146.0, 146.0, 148.0, 148.0,
                ],
                dtype=torch.bfloat16,
            )

            torch.testing.assert_close(logits[0, -1, :40].cpu(), expected_logits, rtol=1e-3, atol=1)

    @require_deterministic_for_xpu
    def test_simple_batched_generate_with_padding(self):
        """Greedy generation on a padded two-prompt batch, plus a logits check on ("cuda", 8)."""
        prompts = ["Hey how are you doing on this lovely evening?", "I am late! I need to"]
        # NOTE(review): the ("cuda", 7) entry is empty, so indexing it below would raise
        # IndexError if that entry is selected — confirm whether that is intended.
        expectations = Expectations(
            {
                ("cuda", 7): [],
                ("cuda", 8): [
                    "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
                    "!!!<|begin_of_text|>I am late! I need to get to work! I have to get to the",
                ],
                ("rocm", 9): [
                    "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
                    "!!!<|begin_of_text|>I am late! I need to be at the airport in 20 minutes! I",
                ],
                ("xpu", 3): [
                    "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all doing well. I am",
                    "!!!<|begin_of_text|>I am late! I need to get to work! I have to get to the",
                ],
            }
        )
        expected_texts = expectations.get_expectation()

        self.model.to(torch_device)

        inputs = self.tokenizer(prompts, padding=True, return_tensors="pt").to(torch_device)
        generated = self.model.generate(**inputs, do_sample=False, max_new_tokens=10)
        decoded = self.tokenizer.batch_decode(generated)
        for row in (0, 1):
            self.assertEqual(decoded[row], expected_texts[row])

        # Reference logits are only available for the ("cuda", 8) device key.
        device_type, device_major = self.device_properties[0], self.device_properties[1]
        if device_type == "cuda" and device_major == 8:
            with torch.no_grad():
                logits = self.model(input_ids=inputs["input_ids"]).logits

            expected_logits_row_0 = torch.tensor(
                [
                    149.0, 142.0, 146.0, 142.0, 143.0, 144.0, 142.0, 145.0,
                    142.0, 146.0, 144.0, 146.0, 147.0, 147.0, 148.0, 145.0,
                    147.0, 145.0, 145.0, 145.0, 145.0, 144.0, 144.0, 144.0,
                    144.0, 145.0, 147.0, 146.0, 144.0, 144.0, 148.0, 147.0,
                    148.0, 147.0, 147.0, 147.0, 146.0, 146.0, 148.0, 148.0,
                ],
                dtype=torch.bfloat16,
            )

            expected_logits_row_1 = torch.tensor(
                [
                    182.0, 178.0, 177.0, 174.0, 176.0, 176.0, 178.0, 178.0,
                    177.0, 179.0, 176.0, 183.0, 180.0, 182.0, 179.0, 174.0,
                    178.0, 176.0, 176.0, 175.0, 175.0, 175.0, 174.0, 173.0,
                    174.0, 182.0, 180.0, 176.0, 177.0, 177.0, 180.0, 176.0,
                    178.0, 177.0, 177.0, 175.0, 176.0, 177.0, 175.0, 177.0,
                ],
                dtype=torch.bfloat16,
            )

            torch.testing.assert_close(logits[0, -1, :40].cpu(), expected_logits_row_0, rtol=1e-3, atol=1)
            torch.testing.assert_close(logits[1, -1, :40].cpu(), expected_logits_row_1, rtol=1e-3, atol=1)
| |
|