| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Testing suite for the PyTorch Arcee model.""" |
|
|
import unittest

from pytest import mark

from transformers import AutoTokenizer, is_torch_available
from transformers.testing_utils import (
    require_flash_attn,
    require_torch,
    require_torch_accelerator,
    slow,
)

from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester


if is_torch_available():
    import torch

    from transformers import (
        ArceeConfig,
        ArceeForCausalLM,
        ArceeForQuestionAnswering,
        ArceeForSequenceClassification,
        ArceeForTokenClassification,
        ArceeModel,
    )
    from transformers.models.arcee.modeling_arcee import ArceeRotaryEmbedding
|
|
|
|
class ArceeModelTester(CausalLMModelTester):
    """Wires the Arcee model classes into the shared causal-LM tester harness."""

    if is_torch_available():
        # Bind the Arcee-specific classes only when torch is importable, so
        # collecting this module in a torch-less environment does not fail.
        config_class = ArceeConfig
        base_model_class = ArceeModel
        causal_lm_class = ArceeForCausalLM
        sequence_class = ArceeForSequenceClassification
        token_class = ArceeForTokenClassification
|
|
|
@require_torch
class ArceeModelTest(CausalLMModelTest, unittest.TestCase):
    """Shared causal-LM model test suite specialized for the Arcee architecture."""

    all_model_classes = (
        (
            ArceeModel,
            ArceeForCausalLM,
            ArceeForSequenceClassification,
            ArceeForQuestionAnswering,
            ArceeForTokenClassification,
        )
        if is_torch_available()
        else ()
    )
    pipeline_model_mapping = (
        {
            "feature-extraction": ArceeModel,
            "text-classification": ArceeForSequenceClassification,
            "text-generation": ArceeForCausalLM,
            "zero-shot": ArceeForSequenceClassification,
            "question-answering": ArceeForQuestionAnswering,
            "token-classification": ArceeForTokenClassification,
        }
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False
    fx_compatible = False
    model_tester_class = ArceeModelTester
    rotary_embedding_layer = ArceeRotaryEmbedding
    model_split_percents = [0.5, 0.7, 0.8]
    _torch_compile_train_cls = ArceeForCausalLM if is_torch_available() else None

    def test_arcee_mlp_uses_relu_squared(self):
        """With hidden_act="relu2", the MLP's act_fn must compute relu(x) squared."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        config.hidden_act = "relu2"
        model = ArceeModel(config)

        first_mlp = model.layers[0].mlp
        sample = torch.randn(1, 10, config.hidden_size)
        pre_activation = first_mlp.up_proj(sample)

        # relu(x) ** 2 == x * relu(x): both equal x**2 for x > 0 and 0 otherwise.
        reference = pre_activation * torch.relu(pre_activation)
        produced = first_mlp.act_fn(pre_activation)

        self.assertTrue(torch.allclose(reference, produced, atol=1e-5))
|
|
|
|
@require_torch_accelerator
class ArceeIntegrationTest(unittest.TestCase):
    """Slow, hardware-gated integration tests for the Arcee model.

    The generation tests are skipped until the pretrained checkpoints are
    publicly released (see the skip reasons on each test).
    """

    def tearDown(self):
        import gc

        # Drop Python-level references first so the allocator cache below can
        # actually reclaim the freed blocks.
        gc.collect()
        # `torch.cuda.empty_cache` only manages the CUDA caching allocator,
        # and @require_torch_accelerator also admits non-CUDA backends
        # (e.g. XPU, MPS) — guard the call instead of invoking torch.cuda
        # unconditionally. NOTE(review): caches of non-CUDA backends are left
        # untouched here; confirm whether a backend-agnostic cleanup helper
        # should be used instead.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @slow
    def test_model_from_pretrained(self):
        """Smoke-test that a randomly initialized causal-LM can be constructed."""
        config = ArceeConfig()
        model = ArceeForCausalLM(config)
        self.assertIsInstance(model, ArceeForCausalLM)

    @mark.skip(reason="Model is not currently public - will update test post release")
    @slow
    def test_model_generation(self):
        """Greedy generation from a short prompt must match the pinned completion."""
        EXPECTED_TEXT_COMPLETION = (
            """Once upon a time,In a village there was a farmer who had three sons. The farmer was very old and he"""
        )
        prompt = "Once upon a time"
        tokenizer = AutoTokenizer.from_pretrained("arcee-ai/model-id")
        model = ArceeForCausalLM.from_pretrained("arcee-ai/model-id", device_map="auto")
        # Place the prompt on the same device as the embedding weights, since
        # device_map="auto" may shard the model across devices.
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)

        generated_ids = model.generate(input_ids, max_new_tokens=20)
        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

    @mark.skip(reason="Model is not currently public - will update test post release")
    @slow
    @require_flash_attn
    @mark.flash_attn_test
    def test_model_generation_flash_attn(self):
        """Long-context generation under flash-attention must match the pinned completion."""
        EXPECTED_TEXT_COMPLETION = (
            " the food, the people, and the overall experience. I would definitely recommend this place to others."
        )
        # The 1024x-repeated sentence forces a long context so flash attention
        # is genuinely exercised.
        prompt = "This is a nice place. " * 1024 + "I really enjoy the scenery,"
        tokenizer = AutoTokenizer.from_pretrained("arcee-ai/model-id")
        model = ArceeForCausalLM.from_pretrained(
            "arcee-ai/model-id", device_map="auto", attn_implementation="flash_attention_2", torch_dtype="auto"
        )
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)

        generated_ids = model.generate(input_ids, max_new_tokens=20)
        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        # Compare only the newly generated suffix, not the echoed prompt.
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text[len(prompt) :])
|
|