# rtdetr_fog/src/solver/det_solver.py
# Committer: miskey607218
# Commit message: First commit of my fog detection app
# Commit: 2a25b9b
'''
by lyuwenyu
'''
import time
import json
import datetime
import torch
from src.misc import dist
from src.data import get_coco_api_from_dataset
from .solver import BaseSolver
from .det_engine import train_one_epoch, evaluate
class DetSolver(BaseSolver):
    """Solver that trains a detector and evaluates it after every epoch.

    ``fit`` runs the training loop (checkpointing, evaluation, logging);
    ``val`` runs a single evaluation pass.
    """

    def fit(self):
        """Train from ``self.last_epoch + 1`` up to ``cfg.epoches``.

        After every epoch the model (or its EMA copy, when one is set) is
        evaluated, the best first-entry value of each reported statistic is
        tracked, and a JSON line is appended to ``log.txt`` on the main
        process.

        Optional keys read from ``cfg.yaml_cfg``:

        * ``save_last_only`` (default ``False``): write ``checkpoint.pth``
          only after the final epoch.
        * ``save_latest_every_epoch`` (default ``False``): refresh
          ``checkpoint.pth`` every epoch and additionally write a numbered
          checkpoint every ``cfg.checkpoint_step`` epochs. When unset, both
          files are written together on periodic epochs and on the final
          epoch.
        * ``save_best`` (default ``False``): write ``best.pth`` whenever the
          tracked metric improves.
        * ``save_best_key`` (default ``'coco_eval_bbox'``) and
          ``save_best_key_index`` (default ``0``): select the tracked metric
          from the evaluation statistics.
        """
        print("Start training")
        self.train()

        args = self.cfg

        n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print('number of params:', n_parameters)

        base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
        best_stat = {'epoch': -1, }

        save_last_only = args.yaml_cfg.get('save_last_only', False)
        save_latest_every_epoch = args.yaml_cfg.get('save_latest_every_epoch', False)
        save_best = args.yaml_cfg.get('save_best', False)
        save_best_key = args.yaml_cfg.get('save_best_key', 'coco_eval_bbox')
        save_best_key_index = int(args.yaml_cfg.get('save_best_key_index', 0))
        # Starts at -inf, so the first evaluated epoch always writes best.pth.
        best_value = float('-inf')

        start_time = time.time()
        for epoch in range(self.last_epoch + 1, args.epoches):
            if dist.is_dist_available_and_initialized():
                # Presumably a DistributedSampler-style sampler that needs the
                # epoch index to vary its shuffling -- confirm against the
                # dataloader setup.
                self.train_dataloader.sampler.set_epoch(epoch)

            self._run_epoch_hooks(epoch)

            train_stats = train_one_epoch(
                self.model, self.criterion, self.train_dataloader, self.optimizer, self.device, epoch,
                args.clip_max_norm, print_freq=args.log_step, ema=self.ema, scaler=self.scaler)

            # The learning-rate schedule advances once per epoch.
            self.lr_scheduler.step()

            if self.output_dir:
                for checkpoint_path in self._checkpoint_paths(
                        epoch, save_last_only, save_latest_every_epoch):
                    dist.save_on_master(self.state_dict(epoch), checkpoint_path)

            # Evaluate the EMA weights when an EMA wrapper is configured.
            module = self.ema.module if self.ema else self.model
            test_stats, coco_evaluator = evaluate(
                module, self.criterion, self.postprocessor, self.val_dataloader, base_ds, self.device, self.output_dir
            )

            # TODO
            # Track the best first entry of every statistic; 'epoch' records
            # the most recent epoch in which any statistic improved.
            for k in test_stats.keys():
                if k in best_stat:
                    best_stat['epoch'] = epoch if test_stats[k][0] > best_stat[k] else best_stat['epoch']
                    best_stat[k] = max(best_stat[k], test_stats[k][0])
                else:
                    best_stat['epoch'] = epoch
                    best_stat[k] = test_stats[k][0]
            print('best_stat: ', best_stat)

            if save_best and self.output_dir and save_best_key in test_stats:
                v = test_stats[save_best_key]
                # Sequence statistics are indexed; anything else is treated
                # as a scalar.
                v = v[save_best_key_index] if isinstance(v, (list, tuple)) else float(v)
                if v > best_value:
                    best_value = v
                    dist.save_on_master(self.state_dict(epoch), self.output_dir / 'best.pth')

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}

            if self.output_dir and dist.is_main_process():
                with (self.output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

                # for evaluation logs
                if coco_evaluator is not None:
                    (self.output_dir / 'eval').mkdir(exist_ok=True)
                    if "bbox" in coco_evaluator.coco_eval:
                        # 'latest.pth' is overwritten every epoch; a numbered
                        # snapshot is kept when epoch is a multiple of 50.
                        filenames = ['latest.pth']
                        if epoch % 50 == 0:
                            filenames.append(f'{epoch:03}.pth')
                        for name in filenames:
                            torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                       self.output_dir / "eval" / name)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f'Training time {total_time_str}')

    def _run_epoch_hooks(self, epoch):
        """Pass ``epoch`` to the optional epoch setters on the model and its encoder.

        Hooks that are not defined are skipped silently. Call order is
        ``encoder.set_fog_gate_epoch``, ``encoder.set_spfm_epoch``, then
        ``set_deb_epoch`` on the model itself.
        """
        # Unwrap a parallel wrapper so attributes are looked up on the real model.
        model_ref = self.model.module if dist.is_parallel(self.model) else self.model
        encoder = getattr(model_ref, "encoder", None)
        hooks = (
            (encoder, "set_fog_gate_epoch"),
            (encoder, "set_spfm_epoch"),
            (model_ref, "set_deb_epoch"),
        )
        for owner, hook_name in hooks:
            # hasattr(None, ...) is False, so a missing encoder is skipped too.
            if hasattr(owner, hook_name):
                getattr(owner, hook_name)(epoch)

    def _checkpoint_paths(self, epoch, save_last_only, save_latest_every_epoch):
        """Return the checkpoint files to write after ``epoch`` (possibly an empty list)."""
        args = self.cfg
        latest = self.output_dir / 'checkpoint.pth'
        numbered = self.output_dir / f'checkpoint{epoch:04}.pth'
        is_final = epoch == args.epoches - 1

        if save_last_only:
            return [latest] if is_final else []

        is_periodic = (epoch + 1) % args.checkpoint_step == 0
        if save_latest_every_epoch:
            return [latest, numbered] if is_periodic else [latest]
        if is_periodic or is_final:
            return [latest, numbered]
        return []

    def val(self):
        """Run one evaluation pass and save the bbox evaluation data.

        When ``self.output_dir`` is set and the evaluator holds bbox results,
        ``coco_eval["bbox"].eval`` is written to ``eval.pth`` from the master
        process. Nothing is written otherwise.
        """
        self.eval()

        base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)

        # Evaluate the EMA weights when an EMA wrapper is configured.
        module = self.ema.module if self.ema else self.model
        _, coco_evaluator = evaluate(module, self.criterion, self.postprocessor,
                                     self.val_dataloader, base_ds, self.device, self.output_dir)

        if not self.output_dir:
            return
        # Same guard as the evaluation-log dump in fit(): skip instead of
        # raising when there is no evaluator or it has no bbox results.
        if coco_evaluator is None or "bbox" not in coco_evaluator.coco_eval:
            return
        dist.save_on_master(coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth")