Spaces:
Build error
Build error
| # Ultralytics YOLO 🚀, AGPL-3.0 license | |
| from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks | |
| try: | |
| assert not TESTS_RUNNING # do not log pytest | |
| assert SETTINGS["dvc"] is True # verify integration is enabled | |
| import dvclive | |
| assert checks.check_version("dvclive", "2.11.0", verbose=True) | |
| import os | |
| import re | |
| from pathlib import Path | |
| # DVCLive logger instance | |
| live = None | |
| _processed_plots = {} | |
| # `on_fit_epoch_end` is called on final validation (probably need to be fixed) for now this is the way we | |
| # distinguish final evaluation of the best model vs last epoch validation | |
| _training_epoch = False | |
| except (ImportError, AssertionError, TypeError): | |
| dvclive = None | |
def _log_images(path, prefix=""):
    """
    Send one image file to the active DVCLive run.

    Images whose file name contains ``_batch<N>`` are logged as ``<stem with _batch>/<N><suffix>`` so that successive
    batches share one key and DVCLive can display them with a slider.

    Args:
        path (Path): Image file to log; its ``name``, ``stem`` and ``suffix`` attributes are used.
        prefix (str): Directory-style prefix joined in front of the logged image name.
    """
    if not live:
        return

    batch_pattern = r"_batch(\d+)"
    image_name = path.name
    match = re.search(batch_pattern, image_name)
    if match:
        batch_index = match[1]
        grouped_stem = re.sub(batch_pattern, "_batch", path.stem)
        image_name = (Path(grouped_stem) / batch_index).with_suffix(path.suffix)

    live.log_image(os.path.join(prefix, image_name), path)
def _log_plots(plots, prefix=""):
    """
    Log each plot whose timestamp differs from the one recorded the last time it was logged.

    Args:
        plots (dict): Maps a plot path to a dict that holds at least a ``"timestamp"`` entry.
        prefix (str): Directory-style prefix passed through to ``_log_images``.
    """
    for plot_path, meta in plots.items():
        stamp = meta["timestamp"]
        if _processed_plots.get(plot_path) == stamp:
            continue  # this exact version of the plot was already logged
        _log_images(plot_path, prefix)
        _processed_plots[plot_path] = stamp
def _log_confusion_matrix(validator):
    """
    Log the validator's confusion matrix to DVCLive via ``log_sklearn_plot``.

    ``log_sklearn_plot`` takes flat lists of true and predicted labels, so every matrix cell is expanded back into
    ``count`` repeated (target, prediction) label pairs.

    Args:
        validator: Object exposing ``confusion_matrix`` (with ``matrix`` and ``task``) and a ``names`` dict.
    """
    counts = validator.confusion_matrix.matrix
    class_names = list(validator.names.values())
    if validator.confusion_matrix.task == "detect":
        # Detection matrices index one extra slot, labelled "background" here.
        class_names += ["background"]

    true_labels = []
    predicted_labels = []
    for target_idx, row in enumerate(counts.T.astype(int)):
        target_name = class_names[target_idx]
        for pred_idx, count in enumerate(row):
            true_labels += [target_name] * count
            predicted_labels += [class_names[pred_idx]] * count

    live.log_sklearn_plot("confusion_matrix", true_labels, predicted_labels, name="cf.json", normalized=True)
def on_pretrain_routine_start(trainer):
    """
    Create the module-level DVCLive ``Live`` run before pre-training starts.

    Any exception raised while creating the run is reported as a warning and swallowed, so training continues
    without DVCLive logging.

    Args:
        trainer: The trainer instance (not used by this hook).
    """
    global live
    try:
        live = dvclive.Live(save_dvc_exp=True, cache_images=True)
        LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}")
def on_pretrain_routine_end(trainer):
    """Log the trainer's plots available at the end of the pre-training routine under the ``train`` prefix."""
    training_plots = trainer.plots
    _log_plots(training_plots, prefix="train")
def on_train_start(trainer):
    """Record ``trainer.args`` as DVCLive run parameters; does nothing when no DVCLive run is active."""
    if not live:
        return
    live.log_params(trainer.args)
def on_train_epoch_start(trainer):
    """
    Mark that a training epoch is in progress.

    ``on_fit_epoch_end`` only logs while this module-level flag is set, and clears it after logging.

    Args:
        trainer: The trainer instance (not used by this hook).
    """
    global _training_epoch
    _training_epoch = True
def on_fit_epoch_end(trainer):
    """
    Log one epoch of metrics and plots to DVCLive, then advance the DVCLive step.

    The hook only acts while ``_training_epoch`` is set (``on_train_epoch_start`` sets it) and clears the flag
    afterwards. This is how the final validation of the best model, which also triggers this hook, is kept from
    being recorded as an extra epoch.

    Args:
        trainer: Trainer providing losses, metrics, learning rates, epoch index, plots and the validator.
    """
    global _training_epoch
    if not (live and _training_epoch):
        return

    epoch_metrics = {
        **trainer.label_loss_items(trainer.tloss, prefix="train"),
        **trainer.metrics,
        **trainer.lr,
    }
    for key, val in epoch_metrics.items():
        live.log_metric(key, val)

    if trainer.epoch == 0:
        # Model info is logged only on the first epoch, without a plot.
        from ultralytics.utils.torch_utils import model_info_for_loggers

        for key, val in model_info_for_loggers(trainer).items():
            live.log_metric(key, val, plot=False)

    _log_plots(trainer.plots, "train")
    _log_plots(trainer.validator.plots, "val")

    live.next_step()
    _training_epoch = False
def on_train_end(trainer):
    """
    Log the final metrics, the remaining plots, the confusion matrix and the best weights, then end the DVCLive run.

    Does nothing when no DVCLive run is active. Final metrics are logged with ``plot=False``.

    Args:
        trainer: Trainer providing losses, metrics, learning rates, plots, the validator and the ``best`` weights path.
    """
    if not live:
        return

    # At the end log the best metrics. It runs validator on the best model internally.
    final_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
    for key, val in final_metrics.items():
        live.log_metric(key, val, plot=False)

    # NOTE(review): trainer.plots is logged under "train" in the other hooks but under "val" here -- confirm intent.
    _log_plots(trainer.plots, "val")
    validator = trainer.validator
    _log_plots(validator.plots, "val")
    _log_confusion_matrix(validator)

    best_weights = trainer.best
    if best_weights.exists():
        live.log_artifact(best_weights, copy=True, type="model")

    live.end()
# Register the DVCLive hooks only when dvclive is available (the import guard sets it to None otherwise).
if dvclive:
    callbacks = {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_pretrain_routine_end": on_pretrain_routine_end,
        "on_train_start": on_train_start,
        "on_train_epoch_start": on_train_epoch_start,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_end": on_train_end,
    }
else:
    callbacks = {}