| import csv, re, os | |
| from pytorch_lightning.loggers import TensorBoardLogger | |
| def read_csv(fname): | |
| with open(fname, 'r') as f: | |
| reader = csv.DictReader(f) | |
| return list(reader) | |
| def append_csv(fname, dicts): | |
| if isinstance(dicts, dict): | |
| dicts = [dicts] | |
| if os.path.isfile(fname): | |
| dicts = read_csv(fname) + dicts | |
| write_csv(fname, dicts) | |
| def write_csv(fname, dicts): | |
| assert len(dicts) > 0 | |
| with open(fname, 'w', newline='') as f: | |
| writer = csv.DictWriter(f, fieldnames=dicts[0].keys()) | |
| writer.writeheader() | |
| for d in dicts: | |
| writer.writerow(d) | |
| def now(): | |
| from datetime import datetime | |
| return datetime.now().strftime('%Y-%m-%d_%H-%M-%S') | |
| def get_info(weights: str): | |
| search = re.search(r"(.*)_epoch=(\d+)-step", weights) | |
| if search: | |
| name, epoch = search.groups() | |
| return str(name).split(os.sep)[-1], str(epoch) | |
| return None, None | |
| def get_matlist(cache_dir, dir): | |
| with open(cache_dir, 'r') as f: | |
| content = f.readlines() | |
| files = [dir/f.strip() for f in content] | |
| return files | |
| def get_logger(args): | |
| logger = TensorBoardLogger(save_dir=args.out_dir) | |
| logger.log_hyperparams(args) | |