File size: 4,573 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import lightning.pytorch as pl
import pytest
import torch
from omegaconf import OmegaConf

from nemo.core.classes import ModelPT
from nemo.utils.exp_manager import exp_manager

# Straggler detection is an optional feature: record whether its callback can
# be imported so that dependent tests can be skipped when it is unavailable.
try:
    # `ptl_resiliency` is included in `gwe_resiliency_pkg` package
    from ptl_resiliency import StragglerDetectionCallback
except ImportError:
    # ModuleNotFoundError is a subclass of ImportError, so both are handled here.
    HAVE_STRAGGLER_DET = False
else:
    HAVE_STRAGGLER_DET = True


class OnesDataset(torch.utils.data.Dataset):
    """Synthetic dataset in which every sample is a 2-element tensor of ones."""

    def __init__(self, dataset_len):
        """Store the number of samples the dataset reports.

        Args:
            dataset_len: value returned by ``len()`` on this dataset.
        """
        super().__init__()
        self._num_samples = dataset_len

    def __getitem__(self, *args):
        """Return a tensor of two ones; the requested index is ignored."""
        return torch.ones((2,))

    def __len__(self):
        """Return the length given at construction time."""
        return self._num_samples


class ExampleModel(ModelPT):
    """Tiny single-linear-layer model used to drive a short training run.

    The model maps 2-feature inputs to one output and is trained to push that
    output towards zero with an L1 loss. It supplies its own train/validation
    dataloaders built from ``OnesDataset``.

    Args:
        log_dir: stored on the instance as ``self.log_dir``.
        **kwargs: accepted for call-site flexibility and ignored.
    """

    def __init__(self, log_dir, **kwargs):
        cfg = OmegaConf.structured({})
        super().__init__(cfg)
        # Fixed seed so the layer initialisation is reproducible between runs.
        pl.seed_everything(1234)
        self.l1 = torch.nn.modules.Linear(in_features=2, out_features=1)
        self.log_dir = log_dir

    def train_dataloader(self):
        """Return a loader over 1024 * 1024 constant samples (batch size 2, 2 workers)."""
        dataset = OnesDataset(1024 * 1024)
        return torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=2)

    def val_dataloader(self):
        """Return a loader over 128 * 1024 constant samples (batch size 2, 2 workers)."""
        dataset = OnesDataset(128 * 1024)
        return torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=2)

    def forward(self, batch):
        """Return the L1 loss between the linear layer's output and zeros.

        Args:
            batch: input tensor passed to the linear layer (2 input features).

        Returns:
            Scalar loss tensor.
        """
        output = self.l1(batch)
        # zeros_like creates the target with output's shape, dtype and device
        # directly, avoiding a CPU allocation followed by a device copy.
        return torch.nn.functional.l1_loss(output, torch.zeros_like(output))

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch and remember it on ``self.loss``."""
        self.loss = self(batch)
        return self.loss

    def training_step(self, batch, batch_idx):
        """Return the loss for one training batch."""
        return self(batch)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=0.1."""
        return torch.optim.Adam(self.parameters(), lr=0.1)

    # The three methods below are intentional no-ops; presumably they exist
    # only to satisfy the ModelPT base-class interface (confirm against ModelPT).
    def list_available_models(self, *args, **kwargs):
        pass

    def setup_training_data(self, *args, **kwargs):
        pass

    def setup_validation_data(self, *args, **kwargs):
        pass

    def on_validation_epoch_end(self):
        """Log ``val_loss`` from the loss stored by the most recent validation step."""
        # Only the last stored loss is available here, so the mean is over one value.
        self.log("val_loss", torch.stack([self.loss]).mean())


@pytest.mark.skipif(not HAVE_STRAGGLER_DET, reason="requires resiliency package to be installed.")
class TestStragglerDetection:
    """Checks that straggler detection reports show up in the NeMo log."""

    @pytest.mark.run_only_on('GPU')
    def test_prints_perf_scores(self, tmp_path):
        """Run a short 1-rank DDP training and look for straggler reports in the log.

        The run is capped at 3 seconds through ``max_time_per_run`` while the
        straggler report interval is 1 second, so the rank-0 log file is
        expected to contain both relative and individual GPU performance lines.
        """
        tmp_path = tmp_path / "test_1"
        print("TMP PATH", tmp_path)

        straggler_params = {
            "report_time_interval": 1.0,
            "calc_relative_gpu_perf": True,
            "calc_individual_gpu_perf": True,
            "num_gpu_perf_scores_to_log": 1,
        }
        exp_manager_cfg = {
            "max_time_per_run": "00:00:00:03",
            "explicit_log_dir": str(tmp_path),
            "create_checkpoint_callback": False,
            "create_straggler_detection_callback": True,
            "straggler_detection_params": straggler_params,
        }

        # The step budget is effectively unbounded; the time limit above ends the run.
        trainer = pl.Trainer(
            strategy='ddp',
            devices=1,
            accelerator='gpu',
            enable_checkpointing=False,
            logger=False,
            max_steps=1_000_000,
            val_check_interval=0.33,
        )
        exp_manager(trainer, exp_manager_cfg)
        trainer.fit(ExampleModel(log_dir=tmp_path))

        # assume that NeMo logs are written into "nemo_log_globalrank-0_localrank-0.txt"
        rank0_log_file = tmp_path / "nemo_log_globalrank-0_localrank-0.txt"
        rank0_log_content = rank0_log_file.read_text()

        for expected_fragment in ("GPU relative performance", "GPU individual performance"):
            assert expected_fragment in rank0_log_content