File size: 3,673 Bytes
bcdf9fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import torch.distributed
class Profiler:
    """Thin wrapper around ``torch.profiler.profile`` for distributed training.

    A real profiler is built only when ``config.use_profile`` is truthy AND
    this process's rank is listed in ``config.profile_ranks``; on every other
    rank ``self.prof`` stays ``None`` and all public methods are no-ops.

    Expected config attributes (project-defined; confirm against caller):
      - use_profile: flag enabling profiling (falsy/None disables everything)
      - profile_ranks: list[int] | None — ranks to profile (defaults to [0])
      - step_start / step_end: int — profiling window, start < end
      - save_path: str — directory for the exported chrome trace
    """

    def __init__(self, config):
        # NOTE: if use_profile is not set (None/falsy), self.prof remains
        # None so that every method below is skipped via check().
        self.config = config
        self.skip_prof = False
        self.saved = False
        self.prof = None
        # Requires torch.distributed to be initialized by the caller.
        self.rank = torch.distributed.get_rank()
        # Validate (and normalize) the config before building the profiler.
        self._validate()
        if config.use_profile and self.rank in self.config.profile_ranks:
            print(f"[Profiler] Profiler init for rank {self.rank}")
            self.prof = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                schedule=torch.profiler.schedule(
                    # Idle until the window opens; warm up one step when possible.
                    wait=max(self.config.step_start - 1, 0),
                    warmup=1 if self.config.step_start > 0 else 0,
                    active=self.config.step_end - self.config.step_start,
                    repeat=1,
                ),
                record_shapes=True,
                with_stack=True,
            )

    def _validate(self):
        """Validate profiling config; fills in the profile_ranks default in place.

        Raises:
            AssertionError: if the step window is negative or empty.
        """
        if self.config.use_profile:
            if self.config.profile_ranks is None:
                print("[WARNING] Profile ranks is not set, default to rank 0")
                self.config.profile_ranks = [0]
            # Fixed: messages previously claimed "greater than 0" while the
            # checks actually accept 0 (non-negative).
            assert self.config.step_start >= 0, "[ERROR] Profile step start must be non-negative"
            assert self.config.step_end >= 0, "[ERROR] Profile step end must be non-negative"
            assert self.config.step_start < self.config.step_end, "[ERROR] Profile step start must be less than step end"

    def check(self):
        """Return True when an active, non-skipped profiler is attached."""
        return self.prof is not None and not self.skip_prof

    def start(self):
        """Start the underlying torch profiler (no-op on non-profiled ranks)."""
        if self.check():
            print(f"[Profiler] started for rank {self.rank}")
            self.prof.start()

    def step(self):
        """Advance the profiler schedule by one step (no-op when inactive)."""
        if self.check():
            self.prof.step()

    def stop(self):
        """Stop the underlying torch profiler (no-op when inactive)."""
        if self.check():
            print(f"[Profiler] stopped for rank {self.rank}")
            self.prof.stop()

    def save(self):
        """Export the chrome trace to config.save_path exactly once.

        Subsequent calls are no-ops (guarded by ``self.saved``); profiling is
        disabled afterwards via ``self.skip_prof``.
        """
        if self.prof is not None and not self.saved:
            # exist_ok avoids the check-then-create race when several
            # processes (or a restart) create the directory concurrently.
            os.makedirs(self.config.save_path, exist_ok=True)
            save_file_name = f"prof_start_{self.config.step_start}_end_{self.config.step_end}_rank_{self.rank}.json"
            save_file_path = os.path.join(self.config.save_path, save_file_name)
            print(f"[Profiler] Saving trace to {save_file_path}")
            self.prof.export_chrome_trace(save_file_path)
            self.skip_prof = True
            self.saved = True

    def stop_and_save(self):
        """Convenience: stop the profiler, then export the trace."""
        if self.check():
            self.stop()
            self.save()

    def stop_trace(self):
        """Disable further profiling without exporting a trace."""
        if self.check():
            print(f"[Profiler] Trace stopped for rank {self.rank}")
            self.skip_prof = True
|