Spaces:
Build error
Build error
| # Ultralytics YOLO 🚀, AGPL-3.0 license | |
| """ | |
| Train a model on a dataset. | |
| Usage: | |
| $ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16 | |
| """ | |
| import math | |
| import os | |
| import subprocess | |
| import time | |
| import warnings | |
| from copy import deepcopy | |
| from datetime import datetime, timedelta | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from torch import distributed as dist | |
| from torch import nn, optim | |
| from ultralytics.cfg import get_cfg, get_save_dir | |
| from ultralytics.data.utils import check_cls_dataset, check_det_dataset | |
| from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights | |
| from ultralytics.utils import ( | |
| DEFAULT_CFG, | |
| LOGGER, | |
| RANK, | |
| TQDM, | |
| __version__, | |
| callbacks, | |
| clean_url, | |
| colorstr, | |
| emojis, | |
| yaml_save, | |
| ) | |
| from ultralytics.utils.autobatch import check_train_batch_size | |
| from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args | |
| from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command | |
| from ultralytics.utils.files import get_latest_run | |
| from ultralytics.utils.torch_utils import ( | |
| EarlyStopping, | |
| ModelEMA, | |
| de_parallel, | |
| init_seeds, | |
| one_cycle, | |
| select_device, | |
| strip_optimizer, | |
| ) | |
| class BaseTrainer: | |
| """ | |
| BaseTrainer. | |
| A base class for creating trainers. | |
| Attributes: | |
| args (SimpleNamespace): Configuration for the trainer. | |
| validator (BaseValidator): Validator instance. | |
| model (nn.Module): Model instance. | |
| callbacks (defaultdict): Dictionary of callbacks. | |
| save_dir (Path): Directory to save results. | |
| wdir (Path): Directory to save weights. | |
| last (Path): Path to the last checkpoint. | |
| best (Path): Path to the best checkpoint. | |
| save_period (int): Save checkpoint every x epochs (disabled if < 1). | |
| batch_size (int): Batch size for training. | |
| epochs (int): Number of epochs to train for. | |
| start_epoch (int): Starting epoch for training. | |
| device (torch.device): Device to use for training. | |
| amp (bool): Flag to enable AMP (Automatic Mixed Precision). | |
| scaler (amp.GradScaler): Gradient scaler for AMP. | |
| data (dict): Dataset dictionary produced by the dataset check (read via keys such as 'train', 'val'/'test', 'names'). | |
| trainset (torch.utils.data.Dataset): Training dataset. | |
| testset (torch.utils.data.Dataset): Testing dataset. | |
| ema (nn.Module): EMA (Exponential Moving Average) of the model. | |
| resume (bool): Resume training from a checkpoint. | |
| lf (Callable): Learning-rate lambda mapping an epoch index to an LR multiplier for the scheduler. | |
| scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler. | |
| best_fitness (float): The best fitness value achieved. | |
| fitness (float): Current fitness value. | |
| loss (float): Current loss value. | |
| tloss (float): Total loss value. | |
| loss_names (list): List of loss names. | |
| csv (Path): Path to results CSV file. | |
| """ | |
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initializes the BaseTrainer class.

    Args:
        cfg (optional): Configuration forwarded to get_cfg(). Defaults to DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides. Defaults to None.
        _callbacks (dict, optional): Mapping of event name to list of callables. When falsy, the default
            callbacks from callbacks.get_default_callbacks() are used.

    Raises:
        RuntimeError: If checking the dataset given by 'args.data' fails for any reason.
    """
    self.args = get_cfg(cfg, overrides)
    self.check_resume(overrides)  # may replace self.args with the checkpoint's training args
    self.device = select_device(self.args.device, self.args.batch)
    self.validator = None
    self.metrics = None
    self.plots = {}
    init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

    # Dirs
    self.save_dir = get_save_dir(self.args)
    self.args.name = self.save_dir.name  # update name for loggers
    self.wdir = self.save_dir / "weights"  # weights dir
    if RANK in (-1, 0):
        self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
        self.args.save_dir = str(self.save_dir)
        yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
    self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
    self.save_period = self.args.save_period

    self.batch_size = self.args.batch
    self.epochs = self.args.epochs
    self.start_epoch = 0
    if RANK == -1:
        print_args(vars(self.args))

    # Device
    if self.device.type in ("cpu", "mps"):
        self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

    # Model and Dataset
    self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
    try:
        if self.args.task == "classify":
            self.data = check_cls_dataset(self.args.data)
        elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
            "detect",
            "segment",
            "pose",
            "obb",
        ):
            self.data = check_det_dataset(self.args.data)
            if "yaml_file" in self.data:
                self.args.data = self.data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
    except Exception as e:
        # Any dataset failure is surfaced as a single RuntimeError with the original exception chained
        raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e

    self.trainset, self.testset = self.get_dataset(self.data)
    self.ema = None

    # Optimization utils init
    self.lf = None
    self.scheduler = None

    # Epoch level metrics
    self.best_fitness = None
    self.fitness = None
    self.loss = None
    self.tloss = None
    self.loss_names = ["Loss"]
    self.csv = self.save_dir / "results.csv"
    self.plot_idx = [0, 1, 2]  # global batch indices at which training samples are plotted

    # Callbacks
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    if RANK in (-1, 0):
        callbacks.add_integration_callbacks(self)
def add_callback(self, event: str, callback):
    """Register an additional callback for 'event', keeping the ones already present."""
    handlers = self.callbacks[event]
    handlers.append(callback)
def set_callback(self, event: str, callback):
    """Replace every callback registered for 'event' with the single given callback."""
    replacement = [callback]
    self.callbacks[event] = replacement
def run_callbacks(self, event: str):
    """Invoke, in registration order, each callback stored for 'event', passing the trainer itself."""
    handlers = self.callbacks.get(event, [])
    for handler in handlers:
        handler(self)
def train(self):
    """
    Run training, launching a DDP subprocess when more than one device is requested.

    The world size is derived from 'self.args.device'. If it is greater than 1 and this process was not
    itself started by a DDP launcher ('LOCAL_RANK' absent from the environment), a DDP command is generated
    and executed as a subprocess; otherwise training runs in-process via _do_train().

    Raises:
        subprocess.CalledProcessError: If the DDP subprocess exits with a non-zero status.
    """
    device = self.args.device
    if isinstance(device, str) and len(device):  # i.e. device='0' or device='0,1,2,3'
        world_size = len(device.split(","))
    elif isinstance(device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
        world_size = len(device)
    elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
        world_size = 1  # default to device 0
    else:  # no explicit device string/list and CUDA unavailable
        world_size = 0

    # Run subprocess if DDP training, else train normally
    if world_size > 1 and "LOCAL_RANK" not in os.environ:
        # Argument checks
        if self.args.rect:
            LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
            self.args.rect = False
        if self.args.batch == -1:
            LOGGER.warning(
                "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
                "default 'batch=16'"
            )
            self.args.batch = 16

        # Command
        cmd, file = generate_ddp_command(world_size, self)
        try:
            LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
            subprocess.run(cmd, check=True)
        finally:
            # Always clean up after the DDP launch, whether it succeeded or raised
            ddp_cleanup(self, str(file))
    else:
        self._do_train(world_size)
| def _setup_scheduler(self): | |
| """Initialize training learning rate scheduler.""" | |
| if self.args.cos_lr: | |
| self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf'] | |
| else: | |
| self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf # linear | |
| self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf) | |
def _setup_ddp(self, world_size):
    """Bind this process to its CUDA device and join the distributed process group."""
    torch.cuda.set_device(RANK)
    self.device = torch.device("cuda", RANK)
    # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
    os.environ["NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
    backend = "nccl" if dist.is_nccl_available() else "gloo"
    three_hours = timedelta(hours=3)
    dist.init_process_group(backend=backend, timeout=three_hours, rank=RANK, world_size=world_size)
def _setup_train(self, world_size):
    """
    Builds the model, dataloaders, optimizer and scheduler on the current rank process.

    Args:
        world_size (int): Number of participating processes. Values above 1 enable the AMP broadcast and
            wrap the model in DistributedDataParallel.
    """
    # Model
    self.run_callbacks("on_pretrain_routine_start")
    ckpt = self.setup_model()
    self.model = self.model.to(self.device)
    self.set_model_attributes()

    # Freeze layers
    if isinstance(self.args.freeze, list):
        freeze_list = self.args.freeze
    elif isinstance(self.args.freeze, int):
        freeze_list = range(self.args.freeze)
    else:
        freeze_list = []
    always_freeze_names = [".dfl"]  # always freeze these layers
    freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
    for k, v in self.model.named_parameters():
        # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
        if any(x in k for x in freeze_layer_names):
            LOGGER.info(f"Freezing layer '{k}'")
            v.requires_grad = False
        elif not v.requires_grad and v.dtype.is_floating_point:  # only floating point Tensor can require gradients
            LOGGER.info(
                f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
                "See ultralytics.engine.trainer for customization of frozen layers."
            )
            v.requires_grad = True

    # Check AMP
    self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
    if self.amp and RANK in (-1, 0):  # Single-GPU and DDP
        callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
        self.amp = torch.tensor(check_amp(self.model), device=self.device)
        callbacks.default_callbacks = callbacks_backup  # restore callbacks
    if RANK > -1 and world_size > 1:  # DDP
        dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
    self.amp = bool(self.amp)  # as boolean
    self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
    if world_size > 1:
        self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])

    # Check imgsz
    gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride)
    self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
    self.stride = gs  # for multiscale training

    # Batch size
    if self.batch_size == -1 and RANK == -1:  # single-GPU only, estimate best batch size
        self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)

    # Dataloaders
    batch_size = self.batch_size // max(world_size, 1)  # per-process batch size
    self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
    if RANK in (-1, 0):
        # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
        self.test_loader = self.get_dataloader(
            self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
        )
        self.validator = self.get_validator()
        metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
        self.metrics = dict.fromkeys(metric_keys, 0)  # every metric starts at zero
        self.ema = ModelEMA(self.model)
        if self.args.plots:
            self.plot_training_labels()

    # Optimizer
    self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
    weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
    iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
    self.optimizer = self.build_optimizer(
        model=self.model,
        name=self.args.optimizer,
        lr=self.args.lr0,
        momentum=self.args.momentum,
        decay=weight_decay,
        iterations=iterations,
    )

    # Scheduler
    self._setup_scheduler()
    self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
    self.resume_training(ckpt)
    self.scheduler.last_epoch = self.start_epoch - 1  # do not move
    self.run_callbacks("on_pretrain_routine_end")
def _do_train(self, world_size=1):
    """
    Run the full training loop, then evaluate and plot if specified by arguments.

    Sets up DDP (when world_size > 1) and the training state, then iterates epochs until 'self.stop' is set.
    Each epoch performs warmup-adjusted forward/backward passes with gradient accumulation, and on rank -1/0
    also validation, metric logging and checkpointing. After the loop, rank -1/0 runs the final evaluation.

    Args:
        world_size (int): Number of participating processes. Defaults to 1.
    """
    if world_size > 1:
        self._setup_ddp(world_size)
    self._setup_train(world_size)

    nb = len(self.train_loader)  # number of batches
    nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
    last_opt_step = -1  # global batch index of the most recent optimizer step
    self.epoch_time = None
    self.epoch_time_start = time.time()
    self.train_time_start = time.time()
    self.run_callbacks("on_train_start")
    LOGGER.info(
        f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
        f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
        f"Logging results to {colorstr('bold', self.save_dir)}\n"
        f'Starting training for ' + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
    )
    if self.args.close_mosaic:
        # Also plot the first three batches after the epoch at which mosaic is switched off
        base_idx = (self.epochs - self.args.close_mosaic) * nb
        self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
    epoch = self.start_epoch
    while True:
        self.epoch = epoch
        self.run_callbacks("on_train_epoch_start")
        self.model.train()
        if RANK != -1:
            self.train_loader.sampler.set_epoch(epoch)
        pbar = enumerate(self.train_loader)
        # Update dataloader attributes (optional)
        if epoch == (self.epochs - self.args.close_mosaic):
            self._close_dataloader_mosaic()
            self.train_loader.reset()

        if RANK in (-1, 0):
            LOGGER.info(self.progress_string())
            pbar = TQDM(enumerate(self.train_loader), total=nb)
        self.tloss = None  # running mean of loss items for this epoch
        self.optimizer.zero_grad()
        for i, batch in pbar:
            self.run_callbacks("on_train_batch_start")
            # Warmup
            ni = i + nb * epoch  # global batch index since the start of training
            if ni <= nw:
                xi = [0, nw]  # x interp
                self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
                for j, x in enumerate(self.optimizer.param_groups):
                    # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x["lr"] = np.interp(
                        ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
                    )
                    if "momentum" in x:
                        x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

            # Forward
            with torch.cuda.amp.autocast(self.amp):
                batch = self.preprocess_batch(batch)
                self.loss, self.loss_items = self.model(batch)
                if RANK != -1:
                    self.loss *= world_size
                self.tloss = (
                    (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
                )

            # Backward
            self.scaler.scale(self.loss).backward()

            # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
            if ni - last_opt_step >= self.accumulate:
                self.optimizer_step()
                last_opt_step = ni

                # Timed stopping
                if self.args.time:
                    self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
                    if RANK != -1:  # if DDP training
                        broadcast_list = [self.stop if RANK == 0 else None]
                        dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
                        self.stop = broadcast_list[0]
                    if self.stop:  # training time exceeded
                        break

            # Log
            mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
            loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
            losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
            if RANK in (-1, 0):
                pbar.set_description(
                    ("%11s" * 2 + "%11.4g" * (2 + loss_len))
                    % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
                )
                self.run_callbacks("on_batch_end")
                if self.args.plots and ni in self.plot_idx:
                    self.plot_training_samples(batch, ni)

            self.run_callbacks("on_train_batch_end")

        self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
        self.run_callbacks("on_train_epoch_end")
        if RANK in (-1, 0):
            final_epoch = epoch + 1 == self.epochs
            self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

            # Validation
            # Runs every 'val_period' epochs or within the last 10 epochs when 'val' is enabled, and
            # always on the final epoch, on a possible early stop, or when a stop was requested
            if (self.args.val and (((epoch+1) % self.args.val_period == 0) or (self.epochs - epoch) <= 10)) \
                    or final_epoch or self.stopper.possible_stop or self.stop:
                self.metrics, self.fitness = self.validate()
            self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
            self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
            if self.args.time:
                self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)

            # Save model
            if self.args.save or final_epoch:
                self.save_model()
                self.run_callbacks("on_model_save")

        # Scheduler
        t = time.time()
        self.epoch_time = t - self.epoch_time_start
        self.epoch_time_start = t
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
            if self.args.time:
                # Re-estimate the total epoch count from the mean epoch duration and rebuild the scheduler
                mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
                self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
                self._setup_scheduler()
                self.scheduler.last_epoch = self.epoch  # do not move
                self.stop |= epoch >= self.epochs  # stop if exceeded epochs
            self.scheduler.step()
        self.run_callbacks("on_fit_epoch_end")
        torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors

        # Early Stopping
        if RANK != -1:  # if DDP training
            broadcast_list = [self.stop if RANK == 0 else None]
            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
            self.stop = broadcast_list[0]
        if self.stop:
            break  # must break all DDP ranks
        epoch += 1

    if RANK in (-1, 0):
        # Do final val with best.pt
        LOGGER.info(
            f"\n{epoch - self.start_epoch + 1} epochs completed in "
            f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
        )
        self.final_eval()
        if self.args.plots:
            self.plot_metrics()
        self.run_callbacks("on_train_end")
    torch.cuda.empty_cache()
    self.run_callbacks("teardown")
def save_model(self):
    """
    Save model training checkpoints with additional metadata.

    The checkpoint is serialized once into memory and the resulting bytes are written to 'last.pt', to
    'best.pt' when the current fitness equals the best fitness, and to 'epoch{N}.pt' on save-period epochs.
    """
    import io

    import pandas as pd  # scope for faster startup

    metrics = {**self.metrics, "fitness": self.fitness}
    results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}

    # Serialize a single time; previously torch.save() re-pickled the same dict for every target file
    buffer = io.BytesIO()
    torch.save(
        {
            "epoch": self.epoch,
            "best_fitness": self.best_fitness,
            "model": deepcopy(de_parallel(self.model)).half(),
            "ema": deepcopy(self.ema.ema).half(),
            "updates": self.ema.updates,
            "optimizer": self.optimizer.state_dict(),
            "train_args": vars(self.args),  # save as dict
            "train_metrics": metrics,
            "train_results": results,
            "date": datetime.now().isoformat(),
            "version": __version__,
            "license": "AGPL-3.0 (https://ultralytics.com/license)",
            "docs": "https://docs.ultralytics.com",
        },
        buffer,
    )
    serialized_ckpt = buffer.getvalue()

    # Save last and best
    self.last.write_bytes(serialized_ckpt)
    if self.best_fitness == self.fitness:
        self.best.write_bytes(serialized_ckpt)
    if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
        (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)
| def get_dataset(data): | |
| """ | |
| Get train, val path from data dict if it exists. | |
| Returns None if data format is not recognized. | |
| """ | |
| return data["train"], data.get("val") or data.get("test") | |
def setup_model(self):
    """
    Load/create/download the model for any task.

    Returns:
        The loaded checkpoint when 'self.model' names a '.pt' file, otherwise None (also None when the model
        is already an nn.Module and nothing needs to be built).
    """
    if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
        return

    source = self.model
    if str(source).endswith(".pt"):
        weights, ckpt = attempt_load_one_weight(source)
        cfg = ckpt["model"].yaml
    else:
        weights, ckpt, cfg = None, None, source
    self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
    return ckpt
def optimizer_step(self):
    """Run one optimizer update: unscale, clip gradients, step, reset gradients, then refresh the EMA."""
    scaler, optimizer = self.scaler, self.optimizer
    scaler.unscale_(optimizer)  # unscale gradients
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    if not self.ema:
        return
    self.ema.update(self.model)
def preprocess_batch(self, batch):
    """Hook for task-specific preprocessing of a batch; the base implementation returns it unchanged."""
    return batch
def validate(self):
    """
    Runs validation on test set using self.validator.

    The returned dict is expected to contain a "fitness" key. When it does not, the negated current
    training loss is used as the fitness measure instead.

    Returns:
        (tuple): (metrics, fitness) where 'metrics' is the validator output with "fitness" removed.
    """
    metrics = self.validator(self)
    if "fitness" in metrics:
        fitness = metrics.pop("fitness")
    else:
        # Compute the fallback only when needed: it reads self.loss, which may not be set yet
        fitness = -self.loss.detach().cpu().numpy()  # use loss as fitness measure if not found
    if not self.best_fitness or self.best_fitness < fitness:
        self.best_fitness = fitness
    return metrics, fitness
def get_model(self, cfg=None, weights=None, verbose=True):
    """Build the task model; must be overridden by task trainers, the base class always raises."""
    message = "This task trainer doesn't support loading cfg files"
    raise NotImplementedError(message)
def get_validator(self):
    """Create the validator; must be overridden by task trainers, the base class always raises."""
    message = "get_validator function not implemented in trainer"
    raise NotImplementedError(message)
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
    """Create a dataloader; must be overridden by task trainers, the base class always raises."""
    message = "get_dataloader function not implemented in trainer"
    raise NotImplementedError(message)
def build_dataset(self, img_path, mode="train", batch=None):
    """Create a dataset; must be overridden by task trainers, the base class always raises."""
    message = "build_dataset function not implemented in trainer"
    raise NotImplementedError(message)
def label_loss_items(self, loss_items=None, prefix="train"):
    """
    Label loss items.

    Returns a dict {"loss": loss_items} when loss items are given, otherwise the list of loss keys ["loss"].
    The 'prefix' argument is not used by this base implementation.
    """
    if loss_items is None:
        return ["loss"]
    return {"loss": loss_items}
def set_model_attributes(self):
    """Copy the dataset class names onto the model before training."""
    class_names = self.data["names"]
    self.model.names = class_names
def build_targets(self, preds, targets):
    """Hook for building target tensors; the base implementation does nothing."""
    return None
def progress_string(self):
    """Return the progress header text logged at each epoch start; empty in the base class."""
    header = ""
    return header
| # TODO: may need to put these following functions into callback | |
def plot_training_samples(self, batch, ni):
    """Hook for plotting a training batch at global batch index 'ni'; no-op in the base class."""
    return None
def plot_training_labels(self):
    """Hook for plotting training labels; no-op in the base class."""
    return None
def save_metrics(self, metrics):
    """Append one row of metrics to the results CSV, writing the header row first if the file is new."""
    names = list(metrics.keys())
    values = list(metrics.values())
    n_cols = len(metrics) + 1  # number of cols
    if self.csv.exists():
        header = ""
    else:
        header = ("%23s," * n_cols % tuple(["epoch"] + names)).rstrip(",") + "\n"  # header
    row = ("%23.5g," * n_cols % tuple([self.epoch + 1] + values)).rstrip(",") + "\n"
    with open(self.csv, "a") as f:
        f.write(header + row)
def plot_metrics(self):
    """Hook for plotting metrics after training; no-op in the base class."""
    return None
def on_plot(self, name, data=None):
    """Record a plot under its Path key with its data and the registration time (e.g. for callbacks)."""
    key = Path(name)
    entry = {"data": data, "timestamp": time.time()}
    self.plots[key] = entry
def final_eval(self):
    """Strip optimizer state from the saved last/best checkpoints and re-validate the best one."""
    for weights_file in (self.last, self.best):
        if not weights_file.exists():
            continue
        strip_optimizer(weights_file)  # strip optimizers
        if weights_file is not self.best:
            continue
        LOGGER.info(f"\nValidating {weights_file}...")
        self.validator.args.plots = self.args.plots
        self.metrics = self.validator(model=weights_file)
        self.metrics.pop("fitness", None)
        self.run_callbacks("on_fit_epoch_end")
def check_resume(self, overrides):
    """
    Check if resume checkpoint exists and update arguments accordingly.

    When 'self.args.resume' is truthy, the training args stored in the checkpoint replace 'self.args';
    only 'imgsz', 'batch' and 'device' may then be overridden by the caller. The outcome is stored in
    'self.resume'.

    Args:
        overrides (dict | None): Configuration overrides; None is treated as no overrides.

    Raises:
        FileNotFoundError: If the resume checkpoint cannot be located or loaded.
    """
    resume = self.args.resume
    if resume:
        # overrides defaults to None in __init__; without this guard 'k in overrides' raised TypeError,
        # which was then misreported below as a missing checkpoint
        overrides = overrides or {}
        try:
            exists = isinstance(resume, (str, Path)) and Path(resume).exists()
            last = Path(check_file(resume) if exists else get_latest_run())

            # Check that resume data YAML exists, otherwise strip to force re-download of dataset
            ckpt_args = attempt_load_weights(last).args
            if not Path(ckpt_args["data"]).exists():
                ckpt_args["data"] = self.args.data

            resume = True
            self.args = get_cfg(ckpt_args)
            self.args.model = self.args.resume = str(last)  # reinstate model
            for k in "imgsz", "batch", "device":  # allow arg updates to reduce memory or update device on resume
                if k in overrides:
                    setattr(self.args, k, overrides[k])

        except Exception as e:
            raise FileNotFoundError(
                "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
                "i.e. 'yolo train resume model=path/to/last.pt'"
            ) from e
    self.resume = resume
def resume_training(self, ckpt):
    """Restore optimizer state, EMA weights, best fitness and start epoch from a checkpoint when resuming."""
    if not (ckpt is not None and self.resume):
        return

    restored_fitness = 0.0
    next_epoch = ckpt["epoch"] + 1
    optimizer_state = ckpt["optimizer"]
    if optimizer_state is not None:
        self.optimizer.load_state_dict(optimizer_state)  # optimizer
        restored_fitness = ckpt["best_fitness"]
    if self.ema and ckpt.get("ema"):
        ema_weights = ckpt["ema"].float().state_dict()
        self.ema.ema.load_state_dict(ema_weights)  # EMA
        self.ema.updates = ckpt["updates"]
    assert next_epoch > 0, (
        f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
        f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
    )
    LOGGER.info(f"Resuming training {self.args.model} from epoch {next_epoch + 1} to {self.epochs} total epochs")
    if next_epoch > self.epochs:
        LOGGER.info(
            f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
        )
        self.epochs += ckpt["epoch"]  # finetune additional epochs
    self.best_fitness = restored_fitness
    self.start_epoch = next_epoch
    mosaic_off_epoch = self.epochs - self.args.close_mosaic
    if next_epoch > mosaic_off_epoch:
        self._close_dataloader_mosaic()
| def _close_dataloader_mosaic(self): | |
| """Update dataloaders to stop using mosaic augmentation.""" | |
| if hasattr(self.train_loader.dataset, "mosaic"): | |
| self.train_loader.dataset.mosaic = False | |
| if hasattr(self.train_loader.dataset, "close_mosaic"): | |
| LOGGER.info("Closing dataloader mosaic") | |
| self.train_loader.dataset.close_mosaic(hyp=self.args) | |
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
    """
    Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
    weight decay, and number of iterations.

    Args:
        model (torch.nn.Module): The model for which to build an optimizer.
        name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
            based on the number of iterations. Default: 'auto'.
        lr (float, optional): The learning rate for the optimizer. Default: 0.001.
        momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
        decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
        iterations (float, optional): The number of iterations, which determines the optimizer if
            name is 'auto'. Default: 1e5.

    Returns:
        (torch.optim.Optimizer): The constructed optimizer.

    Raises:
        NotImplementedError: If 'name' is not one of the supported optimizers.
    """
    g = [], [], []  # optimizer parameter groups: weights with decay, norm weights, biases
    bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
    if name == "auto":
        LOGGER.info(
            f"{colorstr('optimizer:')} 'optimizer=auto' found, "
            f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
            f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
        )
        nc = getattr(model, "nc", 10)  # number of classes
        lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
        name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
        self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam

    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if "bias" in fullname:  # bias (no decay)
                g[2].append(param)
            elif isinstance(module, bn):  # weight (no decay)
                g[1].append(param)
            else:  # weight (with decay)
                g[0].append(param)

    # The optimizer is created with the bias group; the two weight groups are appended afterwards
    if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
        optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    elif name == "RMSProp":
        optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == "SGD":
        optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        raise NotImplementedError(
            f"Optimizer '{name}' not found in list of available optimizers "
            f"[Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, SGD, auto]. "
            "To request support for additional optimizers please visit https://github.com/ultralytics/ultralytics."
        )

    optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
    optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
    LOGGER.info(
        f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
        f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
    )
    return optimizer