| from ..core import * |
| from ..callback import * |
| from ..basic_train import Learner, LearnerCallback |
|
|
| __all__ = ['GeneralScheduler', 'TrainingPhase'] |
|
|
@dataclass
class TrainingPhase():
    "Schedule hyper-parameters for a phase of `length` iterations."
    length:int

    def __post_init__(self):
        # Mapping of hyper-parameter name -> its `Scheduler` for this phase.
        self.scheds = {}

    def schedule_hp(self, name, vals, anneal=None):
        "Adds a schedule for `name` between `vals` using `anneal`."
        # Build the scheduler over this phase's full length, register it,
        # and return self so calls can be chained fluently.
        new_sched = Scheduler(vals, self.length, anneal)
        self.scheds[name] = new_sched
        return self
|
|
class GeneralScheduler(LearnerCallback):
    "Schedule multiple `TrainingPhase` for a `Learner`."
    def __init__(self, learn:Learner, phases:Collection[TrainingPhase], start_epoch:int=None):
        super().__init__(learn)
        self.phases, self.start_epoch = phases, start_epoch

    def on_train_begin(self, epoch:int, **kwargs:Any)->None:
        "Initialize the schedulers for training."
        # Capture the requested resume epoch *before* it gets defaulted below,
        # so the trainer is told to jump only when the caller asked for it.
        res = None if self.start_epoch is None else {'epoch': self.start_epoch}
        self.start_epoch = ifnone(self.start_epoch, epoch)
        self.scheds = [phase.scheds for phase in self.phases]
        self.opt = self.learn.opt
        # Reset the first phase's schedulers and push their starting values
        # into the optimizer's hyper-parameter stats.
        for hp_name, sched in self.scheds[0].items():
            sched.restart()
            self.opt.set_stat(hp_name, sched.start)
        self.idx_s = 0
        return res

    def jump_to_epoch(self, epoch:int)->None:
        "Fast-forward the schedule by replaying `epoch` epochs' worth of batch steps."
        n_steps = len(self.learn.data.train_dl) * epoch
        for _ in range(n_steps):
            self.on_batch_end(True)

    def on_batch_end(self, train, **kwargs:Any)->None:
        "Take a step in lr,mom sched, start next stepper when the current one is complete."
        if not train: return
        # All phases exhausted: signal the trainer to stop.
        if self.idx_s >= len(self.scheds):
            return {'stop_training': True, 'stop_epoch': True}
        current = self.scheds[self.idx_s]
        for hp_name, sched in current.items():
            self.opt.set_stat(hp_name, sched.step())
        # Schedulers in a phase share one length, so checking the first
        # is enough to know the whole phase is done.
        if list(current.values())[0].is_done:
            self.idx_s += 1