Spaces:
Sleeping
Sleeping
"""
RT-DETRv4: Painlessly Furthering Real-Time Object Detection with Vision Foundation Models
Copyright (c) 2025 The RT-DETRv4 Authors. All Rights Reserved.
---------------------------------------------------------------------------------
Modified from DEIM: DETR with Improved Matching for Fast Convergence
Copyright (c) 2024 The DEIM Authors. All Rights Reserved.
"""
| import time | |
| import json | |
| import datetime | |
| import math | |
| import torch | |
| from ..misc import dist_utils, stats | |
| from ._solver import BaseSolver | |
| from .det_engine import train_one_epoch, evaluate | |
| from ..optim.lr_scheduler import FlatCosineLRScheduler | |
class DetSolver(BaseSolver):
    """Solver that trains, evaluates and checkpoints a detection model.

    Training is split into two stages at
    ``train_dataloader.collate_fn.stop_epoch``. Before that epoch the best
    weights are written to ``best_stg1.pth`` (plus ``last.pth`` and periodic
    checkpoints). From that epoch on the best weights are written to
    ``best_stg2.pth``, and an epoch whose metric does not beat the tracked
    best reloads ``best_stg1.pth`` with a slightly smaller EMA decay.
    """

    def fit(self, ) -> None:
        """Run the training loop from ``last_epoch + 1`` up to ``cfg.epoches``.

        Every epoch trains once, optionally retunes the distillation loss
        weight, evaluates the EMA module (or the raw model when EMA is
        disabled), updates the best-metric bookkeeping, saves checkpoints and
        appends one JSON line to ``log.txt`` in ``output_dir``.
        """
        self.train()
        args = self.cfg

        # The parameter count returned by stats() is replaced below by the
        # trainable-parameter count, so only the summary is kept here.
        _, model_stats = stats(self.cfg)
        print(model_stats)
        print("-"*42 + "Start training" + "-"*43)

        self.self_lr_scheduler = False
        if args.lrsheduler is not None:
            iter_per_epoch = len(self.train_dataloader)
            print(" ## Using Self-defined Scheduler-{} ## ".format(args.lrsheduler))
            self.lr_scheduler = FlatCosineLRScheduler(
                self.optimizer, args.lr_gamma, iter_per_epoch, total_epochs=args.epoches,
                warmup_iter=args.warmup_iter, flat_epochs=args.flat_epoch,
                no_aug_epochs=args.no_aug_epoch)
            self.self_lr_scheduler = True

        n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f'number of trainable parameters: {n_parameters}')

        top1 = 0
        best_stat = {'epoch': -1, }
        # When resuming, evaluate first so the best-metric baseline reflects
        # the restored weights.
        if self.last_epoch > 0:
            module = self.ema.module if self.ema else self.model
            test_stats, coco_evaluator = evaluate(
                module,
                self.criterion,
                self.postprocessor,
                self.val_dataloader,
                self.evaluator,
                self.device
            )
            for k in test_stats:
                best_stat['epoch'] = self.last_epoch
                best_stat[k] = test_stats[k][0]
                top1 = test_stats[k][0]
                print(f'best_stat: {best_stat}')

        best_stat_print = best_stat.copy()
        start_time = time.time()
        start_epoch = self.last_epoch + 1
        for epoch in range(start_epoch, args.epoches):
            self.train_dataloader.set_epoch(epoch)
            if dist_utils.is_dist_available_and_initialized():
                self.train_dataloader.sampler.set_epoch(epoch)

            stop_epoch = self.train_dataloader.collate_fn.stop_epoch
            in_stage2 = epoch >= stop_epoch

            # Stage 2 starts from the best stage-1 weights.
            if epoch == stop_epoch:
                self.load_resume_state(str(self.output_dir / 'best_stg1.pth'))
                # EMA is optional everywhere else in this class, so guard it
                # here too instead of failing on ``None.decay``.
                if self.ema:
                    self.ema.decay = self.train_dataloader.collate_fn.ema_restart_decay
                    print(f'Refresh EMA at epoch {epoch} with decay {self.ema.decay}')

            train_stats, grad_percentages = train_one_epoch(
                self.self_lr_scheduler,
                self.lr_scheduler,
                self.model,
                self.criterion,
                self.train_dataloader,
                self.optimizer,
                self.device,
                epoch,
                max_norm=args.clip_max_norm,
                print_freq=args.print_freq,
                ema=self.ema,
                scaler=self.scaler,
                lr_warmup_scheduler=self.lr_warmup_scheduler,
                writer=self.writer,
                teacher_model=self.teacher_model,
            )

            if not self.self_lr_scheduler:  # update by epoch
                if self.lr_warmup_scheduler is None or self.lr_warmup_scheduler.finished():
                    self.lr_scheduler.step()

            self.last_epoch += 1

            self._adapt_distill_weight(epoch, grad_percentages)

            # Rolling / periodic checkpoints are only written during stage 1.
            if self.output_dir and not in_stage2:
                checkpoint_paths = [self.output_dir / 'last.pth']
                # extra checkpoint every ``args.checkpoint_freq`` epochs
                if (epoch + 1) % args.checkpoint_freq == 0:
                    checkpoint_paths.append(self.output_dir / f'checkpoint{epoch:04}.pth')
                for checkpoint_path in checkpoint_paths:
                    dist_utils.save_on_master(self.state_dict(), checkpoint_path)

            module = self.ema.module if self.ema else self.model
            test_stats, coco_evaluator = evaluate(
                module,
                self.criterion,
                self.postprocessor,
                self.val_dataloader,
                self.evaluator,
                self.device
            )

            best_name = 'best_stg2.pth' if in_stage2 else 'best_stg1.pth'
            for k in test_stats:
                if self.writer and dist_utils.is_main_process():
                    for i, v in enumerate(test_stats[k]):
                        self.writer.add_scalar(f'Test/{k}_{i}', v, epoch)

                # The first entry of each metric list is the tracked score.
                score = test_stats[k][0]

                if k in best_stat:
                    best_stat['epoch'] = epoch if score > best_stat[k] else best_stat['epoch']
                    best_stat[k] = max(best_stat[k], score)
                else:
                    best_stat['epoch'] = epoch
                    best_stat[k] = score

                if best_stat[k] > top1:
                    best_stat_print['epoch'] = epoch
                    top1 = best_stat[k]
                    if self.output_dir:
                        dist_utils.save_on_master(self.state_dict(), self.output_dir / best_name)

                best_stat_print[k] = max(best_stat[k], top1)
                print(f'best_stat: {best_stat_print}')  # global best

                if best_stat['epoch'] == epoch and self.output_dir:
                    if in_stage2:
                        if score > top1:
                            top1 = score
                            dist_utils.save_on_master(self.state_dict(), self.output_dir / 'best_stg2.pth')
                    else:
                        top1 = max(score, top1)
                        dist_utils.save_on_master(self.state_dict(), self.output_dir / 'best_stg1.pth')
                elif in_stage2:
                    # No improvement in stage 2: forget the stage-local best,
                    # lower the EMA decay a little and restart from the best
                    # stage-1 weights.
                    best_stat = {'epoch': -1, }
                    if self.ema:
                        self.ema.decay -= 0.0001
                    self.load_resume_state(str(self.output_dir / 'best_stg1.pth'))
                    if self.ema:
                        print(f'Refresh EMA at epoch {epoch} with decay {self.ema.decay}')

            log_stats = {
                **{f'train_{k}': v for k, v in train_stats.items()},
                **{f'test_{k}': v for k, v in test_stats.items()},
                'epoch': epoch,
                'n_parameters': n_parameters
            }

            if self.output_dir and dist_utils.is_main_process():
                with (self.output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

                # for evaluation logs
                if coco_evaluator is not None:
                    (self.output_dir / 'eval').mkdir(exist_ok=True)
                    if "bbox" in coco_evaluator.coco_eval:
                        filenames = ['latest.pth']
                        if epoch % 50 == 0:
                            filenames.append(f'{epoch:03}.pth')
                        for name in filenames:
                            torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                       self.output_dir / "eval" / name)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))

    def _adapt_distill_weight(self, epoch, grad_percentages) -> None:
        """Retune ``criterion.weight_dict['loss_distill']`` after an epoch.

        Does nothing unless this is the main process and the criterion has a
        truthy ``distill_adaptive_params`` mapping with ``enabled`` set.

        Args:
            epoch: Index of the epoch that just finished.
            grad_percentages: Per-step percentages returned by
                ``train_one_epoch``; their mean drives the adjustment.
                Presumably the encoder's share of the gradient (the log line
                calls it "avg encoder grad") -- confirm against
                ``train_one_epoch``.
        """
        # NOTE(review): only the main process updates the weight; confirm that
        # the other ranks are meant to keep the old value in distributed runs.
        if not dist_utils.is_main_process():
            return
        params = getattr(self.criterion, 'distill_adaptive_params', None)
        if not params or not params.get('enabled', False):
            return

        default_weight = params.get('default_weight')
        avg_percentage = sum(grad_percentages) / len(grad_percentages) if grad_percentages else 0.0
        current_weight = self.criterion.weight_dict.get('loss_distill', 0.0)
        new_weight = current_weight
        reason = 'unchanged'

        if avg_percentage < 1e-6:
            # Effectively no measured gradient share: fall back to the default.
            if default_weight is not None:
                new_weight = default_weight
                reason = 'reset_to_default_zero_grad'
        elif epoch >= self.train_dataloader.collate_fn.stop_epoch:
            # Second stage: pin the weight to its default.
            if default_weight is not None:
                new_weight = default_weight
                reason = 'ema_phase_default'
        else:
            rho = params['rho']
            delta = params['delta']
            lower_bound = rho - delta
            upper_bound = rho + delta
            if not (lower_bound <= avg_percentage <= upper_bound):
                # Outside the [rho - delta, rho + delta] band: aim for the
                # opposite edge of the band.
                target_percentage = upper_bound if avg_percentage < lower_bound else lower_bound
                if current_weight > 1e-6:
                    p_current = avg_percentage / 100.0
                    p_target = target_percentage / 100.0
                    # Ratio of the target odds p/(1-p) to the current odds.
                    numerator = p_target * (1.0 - p_current)
                    denominator = p_current * (1.0 - p_target)
                    if abs(denominator) >= 1e-9:
                        ratio = numerator / denominator
                        ratio = max(ratio, 0.1)  # clamp non-positive to 0.1
                        new_weight = current_weight * ratio
                        # Never move by more than 10x in a single epoch.
                        new_weight = min(max(new_weight, current_weight / 10.0), current_weight * 10.0)
                        reason = f'adjusted_to_{target_percentage:.2f}%'

        if abs(new_weight - current_weight) > 0:
            self.criterion.weight_dict['loss_distill'] = new_weight
            print(f"Epoch {epoch}: avg encoder grad {avg_percentage:.2f}% | distill {current_weight:.6f} -> {new_weight:.6f} ({reason})")

    def val(self, ) -> None:
        """Evaluate the EMA module (or the raw model) and save the bbox eval.

        ``eval.pth`` is written to ``output_dir`` only when an evaluator was
        returned and it holds a ``"bbox"`` entry, matching the guards used
        when ``fit`` dumps evaluation results.
        """
        self.eval()

        module = self.ema.module if self.ema else self.model
        test_stats, coco_evaluator = evaluate(module, self.criterion, self.postprocessor,
                                              self.val_dataloader, self.evaluator, self.device)

        if self.output_dir and coco_evaluator is not None and "bbox" in coco_evaluator.coco_eval:
            dist_utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth")

        return

    def state_dict(self) -> dict:
        """Collect a checkpoint dictionary for resuming or evaluation.

        Returns:
            A dict holding the current ISO timestamp under ``'date'``,
            ``last_epoch``, and the ``state_dict()`` of every instance
            attribute that has one, keyed by attribute name. The teacher
            model is skipped.
        """
        state = {}
        state['date'] = datetime.datetime.now().isoformat()

        # For resume
        state['last_epoch'] = self.last_epoch

        for k, v in self.__dict__.items():
            # The teacher model is deliberately left out of checkpoints.
            if k == 'teacher_model':
                continue
            if hasattr(v, 'state_dict'):
                v = dist_utils.de_parallel(v)
                state[k] = v.state_dict()

        return state