| |
| |
| |
| |
| @@ -71,6 +71,8 @@ from sglang.srt.utils import ( |
| suppress_other_loggers, |
| ) |
| from sglang.utils import get_exception_traceback |
| +from rpdTracerControl import rpdTracerControl |
| +rpdTracerControl.skipCreate() |
|
|
| logger = logging.getLogger(__name__) |
|
|
| @@ -245,6 +247,7 @@ class Scheduler: |
| ], |
| with_stack=True, |
| ) |
| + self.rpd = rpdTracerControl() |
|
|
| @torch.inference_mode() |
| def event_loop(self): |
| @@ -1027,15 +1030,24 @@ class Scheduler: |
| def start_profile(self) -> None: |
| if self.profiler is None: |
| raise RuntimeError("Profiler is not enabled.") |
| - self.profiler.start() |
| + # self.profiler.start() is disabled: the PyTorch profiler is blocked so the RPD profiler can be enabled. |
| + if self.tp_rank == 0 or self.tp_rank == 1:  # RPD tracing is started only when tp_rank is 0 or 1 |
| + self.rpd.start() |
| + self.rpd.rangePush("", "rpd profile range", "")  # matching rangePop() is issued in stop_profile |
| + logger.info("rpd is enabled") |
|
|
| def stop_profile(self) -> None: |
| if self.profiler is None: |
| raise RuntimeError("Profiler is not enabled.") |
| - self.profiler.stop() |
| - self.profiler.export_chrome_trace( |
| - self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz" |
| - ) |
| + # PyTorch profiler stop/export is disabled while the RPD tracer is used instead: |
| + #   self.profiler.stop() |
| + #   self.profiler.export_chrome_trace( |
| + #       self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz") |
| + if self.tp_rank ==0 or self.tp_rank ==1:  # same tp_rank filter as start_profile |
| + self.rpd.rangePop()  # closes the "rpd profile range" pushed in start_profile |
| + self.rpd.stop() |
| + self.rpd.flush() |
| + logger.info("rpd is done") |
| logger.info("Profiler is done") |
|
|