|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re |
|
|
from dataclasses import dataclass, field |
|
|
from unittest.mock import MagicMock |
|
|
|
|
|
import pytest |
|
|
import torch |
|
|
|
|
|
from nemo.lightning.pytorch.callbacks.jit_transform import ( |
|
|
JitConfig, |
|
|
JitTransform, |
|
|
compile_module, |
|
|
extract_module_attr_name, |
|
|
get_modules_from_selector, |
|
|
listify, |
|
|
) |
|
|
|
|
|
|
|
|
def test_extract_module_attr_name_with_module(): |
|
|
mock_pl_module = MagicMock(spec=[]) |
|
|
mock_pl_module.module = MagicMock() |
|
|
assert extract_module_attr_name(mock_pl_module) == 'module', mock_pl_module |
|
|
|
|
|
|
|
|
def test_extract_module_attr_name_with_model(): |
|
|
mock_pl_module = MagicMock(spec=[]) |
|
|
mock_pl_module.model = MagicMock() |
|
|
assert extract_module_attr_name(mock_pl_module) == 'model', mock_pl_module |
|
|
|
|
|
|
|
|
def test_extract_module_attr_name_raises(): |
|
|
mock_pl_module = MagicMock(spec=[]) |
|
|
|
|
|
with pytest.raises(ValueError, match="Expected lightning_module to have a .model or .module"): |
|
|
extract_module_attr_name(mock_pl_module) |
|
|
|
|
|
|
|
|
def test_listify_non_list(): |
|
|
assert listify("test") == ["test"] |
|
|
|
|
|
|
|
|
def test_listify_list(): |
|
|
assert listify(["test"]) == ["test"] |
|
|
|
|
|
|
|
|
def test_get_modules_from_selector_none_selector(): |
|
|
model = MagicMock() |
|
|
collected = list(get_modules_from_selector(model, None)) |
|
|
assert collected == [model] |
|
|
|
|
|
|
|
|
def test_get_modules_from_selector_empty_string(): |
|
|
model = MagicMock() |
|
|
collected = list(get_modules_from_selector(model, "")) |
|
|
assert collected == [model] |
|
|
|
|
|
|
|
|
def test_get_modules_from_selector_star(): |
|
|
model = MagicMock() |
|
|
collected = list(get_modules_from_selector(model, "*")) |
|
|
assert collected == [model] |
|
|
|
|
|
|
|
|
def test_get_modules_from_selector_exact_path(): |
|
|
|
|
|
child_module = torch.nn.Linear(3, 3) |
|
|
parent_module = torch.nn.Module() |
|
|
parent_module.encoder = torch.nn.Module() |
|
|
parent_module.encoder.layer = child_module |
|
|
|
|
|
collected = list(get_modules_from_selector(parent_module, "encoder.layer")) |
|
|
assert collected == [child_module] |
|
|
|
|
|
|
|
|
def test_get_modules_from_selector_non_existent_attr(): |
|
|
parent_module = torch.nn.Module() |
|
|
parent_module.encoder = torch.nn.Module() |
|
|
with pytest.raises(AttributeError, match="has no attribute"): |
|
|
list(get_modules_from_selector(parent_module, "decoder")) |
|
|
|
|
|
|
|
|
def test_get_modules_from_selector_attr_is_not_module(): |
|
|
parent_module = torch.nn.Module() |
|
|
parent_module.something = "I am not a module" |
|
|
with pytest.raises(AttributeError, match="is not an nn.Module"): |
|
|
list(get_modules_from_selector(parent_module, "something")) |
|
|
|
|
|
|
|
|
def test_get_modules_from_selector_wildcard_children(): |
|
|
parent_module = torch.nn.Module() |
|
|
parent_module.block1 = torch.nn.Linear(3, 3) |
|
|
parent_module.block2 = torch.nn.Linear(3, 3) |
|
|
collected = list(get_modules_from_selector(parent_module, "block*")) |
|
|
assert len(collected) == 2 |
|
|
|
|
|
|
|
|
def test_jit_config_assertion(): |
|
|
|
|
|
with pytest.raises(AssertionError): |
|
|
JitConfig(use_torch=True, use_thunder=True) |
|
|
|
|
|
|
|
|
def test_compile_module_torch(): |
|
|
mock_module = MagicMock() |
|
|
config = JitConfig(use_torch=True, torch_kwargs={"some_arg": 123}) |
|
|
compiled = compile_module(config, mock_module) |
|
|
mock_module.compile.assert_called_once_with(some_arg=123) |
|
|
assert compiled |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_compile_module_none(): |
|
|
mock_module = MagicMock() |
|
|
config = JitConfig() |
|
|
compiled = compile_module(config, mock_module) |
|
|
mock_module.compile.assert_not_called() |
|
|
assert not compiled |
|
|
|
|
|
|
|
|
def test_jit_transform_no_config(): |
|
|
|
|
|
transform = JitTransform(JitConfig(use_thunder=False, use_torch=False)) |
|
|
trainer_mock = MagicMock() |
|
|
pl_module = MagicMock(spec=[]) |
|
|
transform.on_train_epoch_start(trainer_mock, pl_module) |
|
|
assert not getattr(pl_module, '_compiled', False) |
|
|
|
|
|
|
|
|
def test_jit_transform_already_compiled(): |
|
|
transform = JitTransform(JitConfig(use_torch=True)) |
|
|
trainer_mock = MagicMock() |
|
|
pl_module = MagicMock(spec=[]) |
|
|
pl_module._compiled = True |
|
|
pl_module.module = True |
|
|
transform.on_train_epoch_start(trainer_mock, pl_module) |
|
|
|
|
|
assert pl_module._compiled is True |
|
|
assert pl_module.module == True |
|
|
|
|
|
|
|
|
def test_jit_transform_compile_once(): |
|
|
|
|
|
transform = JitTransform(JitConfig(use_torch=True)) |
|
|
trainer_mock = MagicMock() |
|
|
|
|
|
|
|
|
pl_module = MagicMock() |
|
|
pl_module.module = MagicMock() |
|
|
|
|
|
transform.on_train_epoch_start(trainer_mock, pl_module) |
|
|
assert pl_module._compiled is True |
|
|
|