| import time |
|
|
|
|
class TimeEstimator:
    """Estimate per-iteration time and the remaining run time from periodic updates.

    ``update()`` is meant to be called once every ``step_size`` iterations (not
    necessarily every iteration). Each call measures the interval since the
    previous call -- or since construction, for the very first call -- and feeds
    it into two statistics:

    * a windowed average, read and cleared by ``get_and_reset_avg_time()``;
    * an exponential moving average (EMA), used by ``get_est_remaining()``.

    Intervals are measured with ``time.monotonic()`` so that adjustments to the
    system clock cannot produce negative or inflated durations.
    """

    def __init__(self, total_iter: int, step_size: int, ema_alpha: float = 0.7):
        """
        Args:
            total_iter: total number of iterations the run is expected to take.
            step_size: number of iterations covered by one measured interval
                (i.e. how many iterations pass between two ``update()`` calls).
            ema_alpha: weight given to the previous EMA value on each update;
                the new measurement gets ``1 - ema_alpha``. Larger values
                smooth more heavily.
        """
        self.avg_time_window = []  # raw interval measurements since the last reset
        self.exp_avg_time = None  # EMA of the measured intervals; None until the first one
        self.alpha = ema_alpha

        # Start of the first interval. The first measured interval therefore
        # includes whatever setup happens between construction and update().
        self.last_time = time.monotonic()
        self.total_iter = total_iter
        self.step_size = step_size

        # True until two measurements have been seen; see _record_duration().
        self._buffering_exp = True

    def update(self) -> None:
        """Record the interval elapsed since the previous update() (or construction)."""
        curr_time = time.monotonic()
        self._record_duration(curr_time - self.last_time)
        self.last_time = curr_time

    def _record_duration(self, duration: float) -> None:
        """Add one measured interval to the averaging window and to the EMA."""
        self.avg_time_window.append(duration)

        if self._buffering_exp:
            # The first measurement only seeds the EMA and is overwritten by the
            # second one, so a slow first interval (warm-up / setup time) does
            # not pollute the running average. Regular EMA updates start with
            # the third measurement.
            if self.exp_avg_time is not None:
                self._buffering_exp = False
            self.exp_avg_time = duration
        else:
            self.exp_avg_time = (
                self.alpha * self.exp_avg_time + (1 - self.alpha) * duration
            )

    def get_est_remaining(self, it: int) -> float:
        """Return the estimated seconds needed for the ``total_iter - it`` remaining iterations.

        Returns 0 before any measurement has been recorded. The remaining
        iteration count is clamped at 0, so an ``it`` beyond ``total_iter``
        yields 0 rather than a negative estimate.
        """
        if self.exp_avg_time is None:
            return 0

        remaining_iter = max(self.total_iter - it, 0)
        # exp_avg_time spans step_size iterations; convert to per-iteration time.
        return remaining_iter * self.exp_avg_time / self.step_size

    def get_and_reset_avg_time(self) -> float:
        """Return the mean per-iteration time since the last reset, then clear the window.

        Returns 0 (instead of dividing by zero) when nothing has been recorded
        since the last reset.
        """
        window = self.avg_time_window
        self.avg_time_window = []
        if not window:
            return 0
        # Each entry spans step_size iterations; convert to per-iteration time.
        return sum(window) / len(window) / self.step_size


class PartialTimeEstimator(TimeEstimator):
    """Time estimator for steps where only part of each step should be timed.

    Used where the start time and the end time of a measurement do not align
    with consecutive calls: instead of ``update()``, each measurement runs from
    ``start()`` to the following ``end()``, so time spent between an ``end()``
    and the next ``start()`` is excluded.
    """

    def update(self) -> None:
        """Disabled; use start() and end() instead.

        Raises:
            RuntimeError: always.
        """
        raise RuntimeError('Please use start() and end() for PartialTimeEstimator')

    def start(self) -> None:
        """Mark the beginning of a timed section."""
        self.last_time = time.monotonic()

    def end(self) -> None:
        """Mark the end of a timed section and record its duration.

        Note: ``last_time`` is initialised by ``TimeEstimator.__init__``, so an
        ``end()`` before any ``start()`` measures from construction instead of
        raising.

        Raises:
            RuntimeError: if there is no pending ``start()`` (e.g. ``end()`` was
                called twice in a row).
        """
        # Explicit check rather than ``assert`` so it still fires under ``python -O``.
        if self.last_time is None:
            raise RuntimeError('Please call start() before calling end()')
        duration = time.monotonic() - self.last_time
        self.last_time = None
        self._record_duration(duration)
|
|