| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from types import SimpleNamespace |
| |
|
| | import torch.nn as nn |
| |
|
| | from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs |
| |
|
| |
|
class MockClassWithCudaGraphs(WithOptionalCudaGraphs):
    """Minimal ``WithOptionalCudaGraphs`` implementation that only records its state.

    The ``cuda_graphs_used`` flag starts as ``True``. The two toggle methods
    overwrite it, so tests can observe whether a recursive helper reached
    this object.
    """

    def __init__(self):
        super().__init__()
        # Start in the "enabled" state so a subsequent disable is observable.
        self.cuda_graphs_used = True

    def disable_cuda_graphs(self):
        """Record that CUDA graphs are turned off."""
        self.cuda_graphs_used = False

    def maybe_enable_cuda_graphs(self):
        """Record that CUDA graphs are turned on (this mock always enables)."""
        self.cuda_graphs_used = True
| |
|
| |
|
class MockModuleWithCudaGraphs(MockClassWithCudaGraphs, nn.Module):
    """``nn.Module`` that is itself a CUDA-graphs-aware object.

    Its own inherited ``cuda_graphs_used`` flag is what the recursive
    enable/disable helpers are expected to toggle when given this module.
    """

    def __init__(self):
        # Relies on cooperative super() along the MRO so that nn.Module.__init__
        # runs before a submodule is assigned below (presumably
        # WithOptionalCudaGraphs forwards to super(); confirm if this breaks).
        super().__init__()
        self.linear = nn.Linear(10, 20)

    def forward(self, x):
        """Pass ``x`` through the single linear layer."""
        return self.linear(x)
| |
|
| |
|
class MockModuleWithCudaGraphsByPath(nn.Module):
    """Plain ``nn.Module`` whose CUDA-graphs object lives at ``decoding.decoding``.

    The module itself does not implement ``WithOptionalCudaGraphs``. The mock
    object can only be reached by following the nested attribute path.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 20)
        # A SimpleNamespace is not an nn.Module, so this object is not registered
        # as a submodule; only attribute-path lookup can find the inner mock.
        inner_graphs_holder = MockClassWithCudaGraphs()
        self.decoding = SimpleNamespace(decoding=inner_graphs_holder)

    def forward(self, x):
        """Pass ``x`` through the single linear layer."""
        return self.linear(x)
| |
|
| |
|
class TestWithOptionalCudaGraphs:
    """Tests for the ``WithOptionalCudaGraphs`` recursive enable/disable helpers."""

    def test_module_toggle_cuda_graphs(self):
        """A module that implements the interface itself has its flag toggled."""
        module = MockModuleWithCudaGraphs()
        assert module.cuda_graphs_used

        WithOptionalCudaGraphs.disable_cuda_graphs_recursive(module)
        assert not module.cuda_graphs_used

        WithOptionalCudaGraphs.enable_cuda_graphs_recursive(module)
        assert module.cuda_graphs_used

    def test_module_toggle_cuda_graphs_by_path(self):
        """An object nested under ``attribute_path`` is toggled via that path."""
        module = MockModuleWithCudaGraphsByPath()
        path = "decoding.decoding"
        assert module.decoding.decoding.cuda_graphs_used

        WithOptionalCudaGraphs.disable_cuda_graphs_recursive(module, attribute_path=path)
        assert not module.decoding.decoding.cuda_graphs_used

        WithOptionalCudaGraphs.enable_cuda_graphs_recursive(module, attribute_path=path)
        assert module.decoding.decoding.cuda_graphs_used
| |
|