Respair's picture
Upload folder using huggingface_hub
b386992 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock, patch
import pytest
import torch
from nemo.collections.llm.bert.model.base import (
BertConfig,
BertModel,
bert_forward_step,
get_batch_on_this_cp_rank,
get_packed_seq_params,
)
class TestBertBase:
@pytest.fixture
def sample_batch(self):
return {
"text": torch.randint(0, 1000, (32, 128)), # [batch_size, seq_len]
"padding_mask": torch.ones(32, 128), # [batch_size, seq_len]
"labels": torch.randint(0, 1000, (32, 128)), # [batch_size, seq_len]
"loss_mask": torch.ones(32, 128), # [batch_size, seq_len]
"types": torch.zeros(32, 128), # [batch_size, seq_len]
}
@pytest.fixture
def sample_packed_batch(self):
return {
"cu_seqlens": torch.tensor([0, 10, 25, 35, -1, -1]),
"max_seqlen": torch.tensor(15),
}
@pytest.fixture
def basic_config(self):
return BertConfig(
num_layers=6,
hidden_size=768,
num_attention_heads=12,
)
def test_get_batch_on_this_cp_rank_no_cp(self, sample_batch):
with patch('megatron.core.parallel_state') as mock_parallel_state:
mock_parallel_state.get_context_parallel_world_size.return_value = 1
result = get_batch_on_this_cp_rank(sample_batch)
# When context parallel size is 1, should return original batch unchanged
assert result == sample_batch
def test_get_packed_seq_params(self, sample_packed_batch):
params = get_packed_seq_params(sample_packed_batch)
# Check that cu_seqlens was properly trimmed using cu_seqlens_argmin
assert params.cu_seqlens_q.shape[0] == 4 # Should trim at index 4
assert params.max_seqlen_q == sample_packed_batch["max_seqlen"]
assert params.qkv_format == "thd"
def test_bert_config_initialization(self, basic_config):
assert basic_config.num_layers == 6
assert basic_config.hidden_size == 768
assert basic_config.num_attention_heads == 12
assert basic_config.bert_type == "megatron" # default value
assert basic_config.add_pooler is True # default value
def test_bert_model_initialization(self, basic_config):
tokenizer = MagicMock()
tokenizer.vocab_size = 30000
model = BertModel(config=basic_config, tokenizer=tokenizer)
assert model.config == basic_config
assert model.tokenizer == tokenizer
assert hasattr(model, "optim")
def test_bert_forward_step(self, basic_config, sample_batch):
model = BertModel(config=basic_config, tokenizer=None)
model.module = MagicMock()
bert_forward_step(model, sample_batch)
# Verify model was called with correct arguments
model.module.assert_called_once()
call_args = model.module.call_args[1]
assert "input_ids" in call_args
assert "attention_mask" in call_args
assert "lm_labels" in call_args
assert "loss_mask" in call_args
def test_bert_forward_step_with_tokentypes(self, basic_config, sample_batch):
basic_config.num_tokentypes = 2
model = BertModel(config=basic_config, tokenizer=None)
model.module = MagicMock()
result = bert_forward_step(model, sample_batch)
# Verify tokentype_ids was included in forward call
call_args = model.module.call_args[1]
assert "tokentype_ids" in call_args
def test_bert_forward_step_with_packed_seqs(self, basic_config, sample_batch):
model = BertModel(config=basic_config, tokenizer=None)
model.module = MagicMock()
# Add packed sequence params to batch
sample_batch["cu_seqlens"] = torch.tensor([0, 10, 20])
result = bert_forward_step(model, sample_batch)
# Verify packed_seq_params was included
call_args = model.module.call_args[1]
assert "packed_seq_params" in call_args
def test_bert_model_training_step(self, basic_config):
model = BertModel(config=basic_config, tokenizer=None)
model.forward_step = MagicMock()
batch = {"dummy": "batch"}
output = model.training_step(batch)
model.forward_step.assert_called_once_with(batch)
def test_bert_model_validation_step(self, basic_config):
model = BertModel(config=basic_config, tokenizer=None)
model.forward_step = MagicMock()
batch = {"dummy": "batch"}
output = model.validation_step(batch)
model.forward_step.assert_called_once_with(batch)
def test_get_batch_with_context_parallel(self, sample_batch):
with patch('megatron.core.parallel_state') as mock_parallel_state:
mock_parallel_state.get_context_parallel_world_size.return_value = 2
mock_parallel_state.get_context_parallel_rank.return_value = 0
result = get_batch_on_this_cp_rank(sample_batch)
# Verify batch was properly split for context parallel processing
for key, val in result.items():
if val is not None:
if key != "attention_mask":
assert val.shape[1] == sample_batch[key].shape[1]
else:
assert val.shape[2] == sample_batch[key].shape[2]