Spaces:
Running on Zero
Running on Zero
| ''' | |
| by lyuwenyu | |
| ''' | |
| import time | |
| import json | |
| import datetime | |
| import torch | |
| from src.misc import dist | |
| from src.data import get_coco_api_from_dataset | |
| from .solver import BaseSolver | |
| from .det_engine import train_one_epoch, evaluate | |
class DetSolver(BaseSolver):
    """Solver that runs the detection train/evaluate loop with COCO-style evaluation."""

    def fit(self):
        """Train from ``self.last_epoch + 1`` up to ``cfg.epoches``, evaluating every epoch.

        Checkpointing is driven by optional keys read from ``cfg.yaml_cfg``:

        * ``save_last_only`` (default ``False``): write only ``checkpoint.pth``,
          and only after the final epoch.
        * ``save_latest_every_epoch`` (default ``False``): write
          ``checkpoint.pth`` every epoch, plus a numbered checkpoint whenever
          ``(epoch + 1) % cfg.checkpoint_step == 0``.
        * neither set: on periodic or final epochs write both ``checkpoint.pth``
          and the numbered checkpoint; other epochs write nothing.
        * ``save_best`` (default ``False``), ``save_best_key`` (default
          ``'coco_eval_bbox'``), ``save_best_key_index`` (default ``0``): write
          ``best.pth`` whenever the selected evaluation value strictly improves.

        Per-epoch statistics are appended to ``log.txt`` and the bbox evaluation
        state is stored under ``eval/`` (main process only).
        """
        print("Start training")
        self.train()

        args = self.cfg

        n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print('number of params:', n_parameters)

        base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
        best_stat = {'epoch': -1, }

        save_last_only = args.yaml_cfg.get('save_last_only', False)
        save_latest_every_epoch = args.yaml_cfg.get('save_latest_every_epoch', False)
        save_best = args.yaml_cfg.get('save_best', False)
        save_best_key = args.yaml_cfg.get('save_best_key', 'coco_eval_bbox')
        save_best_key_index = int(args.yaml_cfg.get('save_best_key_index', 0))
        best_value = float('-inf')

        start_time = time.time()
        for epoch in range(self.last_epoch + 1, args.epoches):
            if dist.is_dist_available_and_initialized():
                self.train_dataloader.sampler.set_epoch(epoch)

            # Forward the current epoch to whichever optional epoch-aware hooks
            # the (unwrapped) model exposes.
            model_ref = self.model.module if dist.is_parallel(self.model) else self.model
            encoder = getattr(model_ref, "encoder", None)
            for hook_name in ("set_fog_gate_epoch", "set_spfm_epoch"):
                if hasattr(encoder, hook_name):
                    getattr(encoder, hook_name)(epoch)
            if hasattr(model_ref, "set_deb_epoch"):
                model_ref.set_deb_epoch(epoch)

            train_stats = train_one_epoch(
                self.model, self.criterion, self.train_dataloader, self.optimizer, self.device, epoch,
                args.clip_max_norm, print_freq=args.log_step, ema=self.ema, scaler=self.scaler)

            self.lr_scheduler.step()

            if self.output_dir:
                is_final_epoch = epoch == args.epoches - 1
                if save_last_only:
                    if is_final_epoch:
                        dist.save_on_master(self.state_dict(epoch), self.output_dir / 'checkpoint.pth')
                else:
                    is_periodic_epoch = (epoch + 1) % args.checkpoint_step == 0
                    latest_path = self.output_dir / 'checkpoint.pth'
                    numbered_path = self.output_dir / f'checkpoint{epoch:04}.pth'

                    checkpoint_paths = []
                    if save_latest_every_epoch:
                        checkpoint_paths.append(latest_path)
                        if is_periodic_epoch:
                            checkpoint_paths.append(numbered_path)
                    elif is_periodic_epoch or is_final_epoch:
                        checkpoint_paths.extend((latest_path, numbered_path))

                    for checkpoint_path in checkpoint_paths:
                        dist.save_on_master(self.state_dict(epoch), checkpoint_path)

            # Evaluate the EMA weights when an EMA wrapper is configured.
            module = self.ema.module if self.ema else self.model
            test_stats, coco_evaluator = evaluate(
                module, self.criterion, self.postprocessor, self.val_dataloader, base_ds, self.device, self.output_dir
            )

            # TODO
            # Track the best first value seen for every evaluation key.
            # NOTE: 'epoch' is shared by all keys, so with several keys it
            # reflects whichever key was processed last.
            for k, values in test_stats.items():
                if k in best_stat:
                    if values[0] > best_stat[k]:
                        best_stat['epoch'] = epoch
                    best_stat[k] = max(best_stat[k], values[0])
                else:
                    best_stat['epoch'] = epoch
                    best_stat[k] = values[0]
            print('best_stat: ', best_stat)

            if save_best and self.output_dir and save_best_key in test_stats:
                v = test_stats[save_best_key]
                v = v[save_best_key_index] if isinstance(v, (list, tuple)) else float(v)
                if v > best_value:
                    best_value = v
                    dist.save_on_master(self.state_dict(epoch), self.output_dir / 'best.pth')

            log_stats = {
                **{f'train_{k}': v for k, v in train_stats.items()},
                **{f'test_{k}': v for k, v in test_stats.items()},
                'epoch': epoch,
                'n_parameters': n_parameters,
            }

            if self.output_dir and dist.is_main_process():
                with (self.output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

                # for evaluation logs
                if coco_evaluator is not None:
                    (self.output_dir / 'eval').mkdir(exist_ok=True)
                    if "bbox" in coco_evaluator.coco_eval:
                        filenames = ['latest.pth']
                        if epoch % 50 == 0:
                            filenames.append(f'{epoch:03}.pth')
                        for name in filenames:
                            torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                       self.output_dir / "eval" / name)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f'Training time {total_time_str}')

    def val(self):
        """Run one evaluation pass and store the bbox evaluation state in ``eval.pth``.

        The EMA weights are evaluated when an EMA wrapper is configured. The
        dump is skipped when there is no output directory or no bbox evaluator,
        using the same guard as the evaluation logging in ``fit``.
        """
        self.eval()

        base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)

        module = self.ema.module if self.ema else self.model
        test_stats, coco_evaluator = evaluate(module, self.criterion, self.postprocessor,
                                              self.val_dataloader, base_ds, self.device, self.output_dir)

        # Guard the evaluator access so a missing bbox evaluator does not
        # raise after the evaluation pass has already completed.
        has_bbox_eval = coco_evaluator is not None and "bbox" in coco_evaluator.coco_eval
        if self.output_dir and has_bbox_eval:
            dist.save_on_master(coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth")