| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from unittest.mock import MagicMock, patch |
| |
|
| | import pytest |
| | import torch |
| |
|
| | from nemo.collections.llm.bert.model.base import ( |
| | BertConfig, |
| | BertModel, |
| | bert_forward_step, |
| | get_batch_on_this_cp_rank, |
| | get_packed_seq_params, |
| | ) |
| |
|
| |
|
class TestBertBase:
    """Unit tests for helpers in ``nemo.collections.llm.bert.model.base``.

    Covers: context-parallel batch slicing, packed-sequence parameter
    construction, ``BertConfig``/``BertModel`` initialization, and the
    ``bert_forward_step`` argument plumbing (via a mocked ``model.module``).
    """

    @pytest.fixture
    def sample_batch(self):
        """A standard (batch=32, seq=128) Bert batch with all-ones masks."""
        return {
            "text": torch.randint(0, 1000, (32, 128)),
            "padding_mask": torch.ones(32, 128),
            "labels": torch.randint(0, 1000, (32, 128)),
            "loss_mask": torch.ones(32, 128),
            "types": torch.zeros(32, 128),
        }

    @pytest.fixture
    def sample_packed_batch(self):
        """A packed-sequence batch; trailing -1 entries are padding slots."""
        return {
            "cu_seqlens": torch.tensor([0, 10, 25, 35, -1, -1]),
            "max_seqlen": torch.tensor(15),
        }

    @pytest.fixture
    def basic_config(self):
        """A minimal 6-layer BertConfig used across the model tests."""
        return BertConfig(
            num_layers=6,
            hidden_size=768,
            num_attention_heads=12,
        )

    def test_get_batch_on_this_cp_rank_no_cp(self, sample_batch):
        """With context-parallel world size 1 the batch passes through unchanged."""
        # NOTE(review): patching 'megatron.core.parallel_state' only takes effect
        # if the module under test resolves parallel_state dynamically; if it did
        # `from megatron.core import parallel_state`, the target should be
        # 'nemo.collections.llm.bert.model.base.parallel_state' — confirm.
        with patch('megatron.core.parallel_state') as mock_parallel_state:
            mock_parallel_state.get_context_parallel_world_size.return_value = 1

            result = get_batch_on_this_cp_rank(sample_batch)

            assert result == sample_batch

    def test_get_packed_seq_params(self, sample_packed_batch):
        """Padding (-1) entries are stripped and the THD layout is selected."""
        params = get_packed_seq_params(sample_packed_batch)

        # Only the 4 real cu_seqlens entries survive; the two -1 pads are dropped.
        assert params.cu_seqlens_q.shape[0] == 4
        assert params.max_seqlen_q == sample_packed_batch["max_seqlen"]
        assert params.qkv_format == "thd"

    def test_bert_config_initialization(self, basic_config):
        """Explicit args are stored; bert_type/add_pooler keep their defaults."""
        assert basic_config.num_layers == 6
        assert basic_config.hidden_size == 768
        assert basic_config.num_attention_heads == 12
        assert basic_config.bert_type == "megatron"
        assert basic_config.add_pooler is True

    def test_bert_model_initialization(self, basic_config):
        """Config and tokenizer are attached and an optimizer slot exists."""
        tokenizer = MagicMock()
        tokenizer.vocab_size = 30000

        model = BertModel(config=basic_config, tokenizer=tokenizer)

        assert model.config == basic_config
        assert model.tokenizer == tokenizer
        assert hasattr(model, "optim")

    def test_bert_forward_step(self, basic_config, sample_batch):
        """bert_forward_step forwards the expected keyword args to the module."""
        model = BertModel(config=basic_config, tokenizer=None)
        model.module = MagicMock()

        bert_forward_step(model, sample_batch)

        model.module.assert_called_once()
        call_args = model.module.call_args[1]
        assert "input_ids" in call_args
        assert "attention_mask" in call_args
        assert "lm_labels" in call_args
        assert "loss_mask" in call_args

    def test_bert_forward_step_with_tokentypes(self, basic_config, sample_batch):
        """When num_tokentypes > 0, tokentype_ids is passed through."""
        basic_config.num_tokentypes = 2
        model = BertModel(config=basic_config, tokenizer=None)
        model.module = MagicMock()

        bert_forward_step(model, sample_batch)

        call_args = model.module.call_args[1]
        assert "tokentype_ids" in call_args

    def test_bert_forward_step_with_packed_seqs(self, basic_config, sample_batch):
        """A batch containing cu_seqlens triggers packed_seq_params plumbing."""
        model = BertModel(config=basic_config, tokenizer=None)
        model.module = MagicMock()

        sample_batch["cu_seqlens"] = torch.tensor([0, 10, 20])

        bert_forward_step(model, sample_batch)

        call_args = model.module.call_args[1]
        assert "packed_seq_params" in call_args

    def test_bert_model_training_step(self, basic_config):
        """training_step delegates to forward_step with the batch unchanged."""
        model = BertModel(config=basic_config, tokenizer=None)
        model.forward_step = MagicMock()

        batch = {"dummy": "batch"}
        model.training_step(batch)

        model.forward_step.assert_called_once_with(batch)

    def test_bert_model_validation_step(self, basic_config):
        """validation_step delegates to forward_step with the batch unchanged."""
        model = BertModel(config=basic_config, tokenizer=None)
        model.forward_step = MagicMock()

        batch = {"dummy": "batch"}
        model.validation_step(batch)

        model.forward_step.assert_called_once_with(batch)

    def test_get_batch_with_context_parallel(self, sample_batch):
        """With cp world size 2 / rank 0, returned tensors keep the seq dim."""
        # NOTE(review): same patch-target caveat as in the no-cp test above.
        with patch('megatron.core.parallel_state') as mock_parallel_state:
            mock_parallel_state.get_context_parallel_world_size.return_value = 2
            mock_parallel_state.get_context_parallel_rank.return_value = 0

            result = get_batch_on_this_cp_rank(sample_batch)

            for key, val in result.items():
                if val is not None:
                    # attention_mask carries an extra head dim, so its sequence
                    # length lives at index 2 rather than index 1.
                    if key != "attention_mask":
                        assert val.shape[1] == sample_batch[key].shape[1]
                    else:
                        assert val.shape[2] == sample_batch[key].shape[2]