Respair's picture
Upload folder using huggingface_hub
b386992 verified
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024 Arc Institute. All rights reserved.
# Copyright (c) 2024 Michael Poli. All rights reserved.
# Copyright (c) 2024 Stanford University. All rights reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import dataclass, field
from unittest.mock import MagicMock
import pytest
import torch
from nemo.lightning.pytorch.callbacks.jit_transform import (
JitConfig,
JitTransform,
compile_module,
extract_module_attr_name,
get_modules_from_selector,
listify,
)
def test_extract_module_attr_name_with_module():
mock_pl_module = MagicMock(spec=[])
mock_pl_module.module = MagicMock()
assert extract_module_attr_name(mock_pl_module) == 'module', mock_pl_module
def test_extract_module_attr_name_with_model():
mock_pl_module = MagicMock(spec=[])
mock_pl_module.model = MagicMock()
assert extract_module_attr_name(mock_pl_module) == 'model', mock_pl_module
def test_extract_module_attr_name_raises():
mock_pl_module = MagicMock(spec=[])
# no 'module' or 'model'
with pytest.raises(ValueError, match="Expected lightning_module to have a .model or .module"):
extract_module_attr_name(mock_pl_module)
def test_listify_non_list():
assert listify("test") == ["test"]
def test_listify_list():
assert listify(["test"]) == ["test"]
def test_get_modules_from_selector_none_selector():
model = MagicMock()
collected = list(get_modules_from_selector(model, None))
assert collected == [model]
def test_get_modules_from_selector_empty_string():
model = MagicMock()
collected = list(get_modules_from_selector(model, ""))
assert collected == [model]
def test_get_modules_from_selector_star():
model = MagicMock()
collected = list(get_modules_from_selector(model, "*"))
assert collected == [model]
def test_get_modules_from_selector_exact_path():
# Example: model.encoder.layer
child_module = torch.nn.Linear(3, 3)
parent_module = torch.nn.Module()
parent_module.encoder = torch.nn.Module()
parent_module.encoder.layer = child_module
collected = list(get_modules_from_selector(parent_module, "encoder.layer"))
assert collected == [child_module]
def test_get_modules_from_selector_non_existent_attr():
parent_module = torch.nn.Module()
parent_module.encoder = torch.nn.Module()
with pytest.raises(AttributeError, match="has no attribute"):
list(get_modules_from_selector(parent_module, "decoder"))
def test_get_modules_from_selector_attr_is_not_module():
parent_module = torch.nn.Module()
parent_module.something = "I am not a module"
with pytest.raises(AttributeError, match="is not an nn.Module"):
list(get_modules_from_selector(parent_module, "something"))
def test_get_modules_from_selector_wildcard_children():
parent_module = torch.nn.Module()
parent_module.block1 = torch.nn.Linear(3, 3)
parent_module.block2 = torch.nn.Linear(3, 3)
collected = list(get_modules_from_selector(parent_module, "block*"))
assert len(collected) == 2
def test_jit_config_assertion():
# Should raise if both use_torch and use_thunder
with pytest.raises(AssertionError):
JitConfig(use_torch=True, use_thunder=True)
def test_compile_module_torch():
mock_module = MagicMock()
config = JitConfig(use_torch=True, torch_kwargs={"some_arg": 123})
compiled = compile_module(config, mock_module)
mock_module.compile.assert_called_once_with(some_arg=123)
assert compiled
# Disabling due to issue with 25.03 https://github.com/pytorch/pytorch/issues/144567
# def test_compile_module_thunder():
# mock_module = MagicMock()
# config = JitConfig(use_thunder=True)
# compiled = compile_module(config, mock_module)
# mock_module.compile.assert_called_once()
# assert compiled
def test_compile_module_none():
mock_module = MagicMock()
config = JitConfig()
compiled = compile_module(config, mock_module)
mock_module.compile.assert_not_called()
assert not compiled
def test_jit_transform_no_config():
# If config is None, on_train_epoch_start returns early
transform = JitTransform(JitConfig(use_thunder=False, use_torch=False))
trainer_mock = MagicMock()
pl_module = MagicMock(spec=[])
transform.on_train_epoch_start(trainer_mock, pl_module)
assert not getattr(pl_module, '_compiled', False)
def test_jit_transform_already_compiled():
transform = JitTransform(JitConfig(use_torch=True))
trainer_mock = MagicMock()
pl_module = MagicMock(spec=[])
pl_module._compiled = True
pl_module.module = True
transform.on_train_epoch_start(trainer_mock, pl_module)
# Should remain True, and compile should not be called again
assert pl_module._compiled is True
assert pl_module.module == True
def test_jit_transform_compile_once():
# simulate successful compile (torch or thunder)
transform = JitTransform(JitConfig(use_torch=True))
trainer_mock = MagicMock()
# pl_module with the 'module' attribute (matching whatever name you expect inside transform)
pl_module = MagicMock()
pl_module.module = MagicMock()
transform.on_train_epoch_start(trainer_mock, pl_module)
assert pl_module._compiled is True