import os
import shutil
import tempfile
import unittest
from unittest.mock import patch
import torch
from transformers import LlamaConfig
from specforge.modeling.draft.llama3_eagle import (
LlamaAttention,
LlamaForCausalLMEagle3,
LlamaMLP,
LlamaRMSNorm,
)
# from model_module import LlamaForCausalLMEagle3
class TestLlamaForCausalLMEagle3Loading(unittest.TestCase):
    """Tests for LlamaForCausalLMEagle3: construction, persistence, forward pass,
    state-dict round-tripping, and config validation."""

    def setUp(self):
        """Create a scratch directory and a small single-layer Llama config."""
        self.temp_dir = tempfile.mkdtemp()
        config_dict = {
            "architectures": ["LlamaForCausalLM"],
            "bos_token_id": 128000,
            "eos_token_id": 128001,
            "hidden_act": "silu",
            "hidden_size": 4096,
            "initializer_range": 0.02,
            "intermediate_size": 14336,
            "max_position_embeddings": 2048,
            "model_type": "llama",
            "num_attention_heads": 32,
            "num_key_value_heads": 8,
            "num_hidden_layers": 1,
            "pad_token_id": 0,
            "rms_norm_eps": 1e-05,
            "tie_word_embeddings": False,
            "torch_dtype": "float16",
            "transformers_version": "4.28.1",
            "use_cache": True,
            "vocab_size": 128256,
            "draft_vocab_size": 32000,
        }
        self.config = LlamaConfig(**config_dict)

    def tearDown(self):
        """Remove the scratch directory created in setUp."""
        shutil.rmtree(self.temp_dir)

    def test_model_initialization(self):
        """The draft model's mid-layer should expose the expected submodules."""
        model = LlamaForCausalLMEagle3(self.config)
        self.assertIsInstance(model.midlayer.self_attn, LlamaAttention)
        self.assertIsInstance(model.midlayer.mlp, LlamaMLP)
        self.assertIsInstance(model.midlayer.hidden_norm, LlamaRMSNorm)
        self.assertIsInstance(model.midlayer.input_layernorm, LlamaRMSNorm)
        self.assertIsInstance(model.midlayer.post_attention_layernorm, LlamaRMSNorm)
        self.assertEqual(model.midlayer.hidden_size, self.config.hidden_size)

    def test_save_pretrained(self):
        """Saving the config and the raw state dict should produce both files."""
        model = LlamaForCausalLMEagle3(self.config)
        self.config.save_pretrained(self.temp_dir)
        model_path = os.path.join(self.temp_dir, "pytorch_model.bin")
        torch.save(model.state_dict(), model_path)
        self.assertTrue(os.path.exists(os.path.join(self.temp_dir, "config.json")))
        self.assertTrue(os.path.exists(model_path))

    @patch("transformers.modeling_utils.PreTrainedModel.from_pretrained")
    def test_from_pretrained_mock(self, mock_from_pretrained):
        """from_pretrained should delegate to the (mocked) base-class loader."""
        mock_model = LlamaForCausalLMEagle3(self.config)
        mock_from_pretrained.return_value = mock_model
        loaded_model = LlamaForCausalLMEagle3.from_pretrained(self.temp_dir)
        mock_from_pretrained.assert_called_once_with(self.temp_dir)
        self.assertIsInstance(loaded_model, LlamaForCausalLMEagle3)

    def test_model_forward_pass(self):
        """A forward pass should return a (batch, seq, hidden_size) tensor."""
        model = LlamaForCausalLMEagle3(self.config)
        model.eval()
        batch_size = 2
        seq_len = 10
        input_emb = torch.randn(batch_size, seq_len, self.config.hidden_size)
        # NOTE(review): hidden_states is 3x hidden_size — presumably the Eagle3
        # draft layer consumes a concatenation of three target-model hidden
        # states; confirm against the model's forward signature.
        hidden_states = torch.randn(batch_size, seq_len, self.config.hidden_size * 3)
        attention_mask = torch.ones(batch_size, seq_len)
        with torch.no_grad():
            outputs = model(
                inputs_embeds=input_emb,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
            )
        self.assertEqual(outputs.shape, (batch_size, seq_len, self.config.hidden_size))

    def test_state_dict_compatibility(self):
        """A state dict from one instance should load cleanly into another."""
        model1 = LlamaForCausalLMEagle3(self.config)
        model2 = LlamaForCausalLMEagle3(self.config)
        model2.load_state_dict(model1.state_dict())
        # Build the name -> parameter mapping once; the original rebuilt
        # dict(model2.named_parameters()) on every loop iteration (O(n^2)).
        params2 = dict(model2.named_parameters())
        for name, param1 in model1.named_parameters():
            self.assertTrue(torch.equal(param1, params2[name]))

    def test_config_validation(self):
        """An inconsistent config should make model construction fail."""
        invalid_config = LlamaConfig(
            vocab_size=1000,
            hidden_size=127,  # deliberately not divisible by num_attention_heads
            num_attention_heads=4,
            num_key_value_heads=2,
        )
        with self.assertRaises(AttributeError):
            LlamaForCausalLMEagle3(invalid_config)
if __name__ == "__main__":
    # unittest.makeSuite was deprecated in Python 3.11 and removed in 3.13;
    # build the suite through the TestLoader API instead.
    suite = unittest.TestSuite()
    suite.addTests(
        unittest.defaultTestLoader.loadTestsFromTestCase(
            TestLlamaForCausalLMEagle3Loading
        )
    )
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite)