| from ..torch_core import * |
| from ..callback import * |
| from ..basic_train import Learner, LearnerCallback |
|
|
# Public API of this module: only the callback class is exported.
__all__ = ["LossMetrics"]
|
|
class LossMetrics(LearnerCallback):
    """Add `loss_func.metrics` to metrics named by `loss_func.metric_names`.

    The learner's loss function is expected to expose:

    - `metric_names`: the names of the extra values it computes (optional; if it is
      missing or `None` a warning is emitted and nothing is added);
    - `metrics`: a mapping from each of those names to the value computed for the
      latest batch (something supporting `.detach().cpu()`, e.g. a tensor).

    On validation batches each value is weighted by the batch size and accumulated;
    at the end of the epoch the averages are appended to `last_metrics`.
    """
    # Negative order so this callback is scheduled early; presumably it must run before
    # the `Recorder` so the extra names/values are in place when it reads them — confirm.
    _order = -20

    def on_train_begin(self, **kwargs):
        "Add the metrics names to the `Recorder`."
        # `getattr` so a loss function with no `metric_names` attribute reaches the warning
        # below instead of raising AttributeError; `list` so the names can be iterated
        # again every epoch even if a one-shot iterable was supplied.
        self.names = list(ifnone(getattr(self.learn.loss_func, 'metric_names', None), []))
        if not self.names: warn('LossMetrics requested but no loss_func.metric_names provided')
        self.learn.recorder.add_metric_names(self.names)

    def on_epoch_begin(self, **kwargs):
        "Initialize the metrics for this epoch."
        self.metrics = {name: 0. for name in self.names}
        self.nums = 0

    def on_batch_end(self, last_target, train, **kwargs):
        "Update the metrics if not `train`"
        if train: return
        # With several targets `last_target` is a list/tuple; take the batch size from
        # the first one instead of crashing on `.size` of a list.
        target = last_target[0] if isinstance(last_target, (list, tuple)) else last_target
        bs = target.size(0)
        batch_metrics = self.learn.loss_func.metrics
        for name in self.names:
            # Weight by batch size so the division in `on_epoch_end` gives a per-sample
            # average (assuming each value is a per-batch mean), even for a short last batch.
            self.metrics[name] += bs * batch_metrics[name].detach().cpu()
        self.nums += bs

    def on_epoch_end(self, last_metrics, **kwargs):
        "Finish the computation and send the result to the `Recorder`."
        # No validation batch was seen this epoch: leave `last_metrics` untouched.
        if not self.nums: return
        metrics = [self.metrics[name] / self.nums for name in self.names]
        return {'last_metrics': last_metrics + metrics}
|
|