| import os |
| import sys |
| import time |
| from datetime import datetime |
| import logging |
| import numpy as np |
| import torch |
| import math |
|
|
def get_timestamp():
    """Return the current local time formatted as ``YYMMDD-HHMMSS``.

    Returns:
        str: e.g. ``'240131-235959'``.
    """
    stamp_format = '%y%m%d-%H%M%S'
    current = datetime.now()
    return format(current, stamp_format)
|
|
def mkdir_and_rename(path):
    """Create the directory ``path``, archiving any existing one first.

    If ``path`` already exists, it is renamed to
    ``<path>_archived_<timestamp>`` (timestamp from ``get_timestamp``) and a
    message is printed. A fresh directory is then created at ``path``.

    Args:
        path (str): Directory to create.
    """
    if os.path.exists(path):
        # Strip trailing separators before building the archive name.
        # Otherwise 'exp/' + '_archived_...' names a location *inside* the
        # directory being renamed, and os.rename fails.
        separators = os.sep + (os.altsep or '')
        base = path.rstrip(separators) or path
        new_name = base + '_archived_' + get_timestamp()
        print('Path already exists. Rename it to [{:s}]'.format(new_name))
        os.rename(base, new_name)
    os.makedirs(path)
|
|
|
|
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the files of interest.

    Entries whose names start with '.' are never yielded as files. With
    ``recursive=True``, sub-directories (including hidden ones) are
    traversed. Entries that are neither files nor directories are skipped.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix(es) to keep. If
            None, every file is kept. Default: None.
        recursive (bool, optional): If True, also scan sub-directories.
            Default: False.
        full_path (bool, optional): If True, yield paths that include
            ``dir_path``; otherwise yield paths relative to ``dir_path``.
            Default: False.

    Returns:
        A generator over the paths of the matching files.

    Raises:
        TypeError: If ``suffix`` is not None, a str, or a tuple. This is
            raised when ``scandir`` is called, not on first iteration.
    """
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        # Use the iterator as a context manager so the OS directory handle
        # is released even if the consumer abandons the generator early.
        with os.scandir(dir_path) as entries:
            for entry in entries:
                if not entry.name.startswith('.') and entry.is_file():
                    if full_path:
                        return_path = entry.path
                    else:
                        return_path = os.path.relpath(entry.path, root)

                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    # Only descend into directories. Hidden files, broken
                    # symlinks, etc. used to reach os.scandir here and
                    # raise NotADirectoryError / FileNotFoundError.
                    yield from _scandir(
                        entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
|
|
|
def setup_logger(log_file_path):
    """Configure the root logger to log at INFO level to a file and stdout.

    Both handlers use the format ``"%(asctime)s [%(levelname)-5.5s] %(message)s"``.
    Calling this again with the same ``log_file_path`` does not attach
    duplicate handlers, so records are not emitted twice.

    Args:
        log_file_path (str): File that log records are appended to (UTF-8).
    """
    log_formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # FileHandler stores the absolute path in ``baseFilename``; compare
    # against it to detect a handler attached by an earlier call.
    abs_log_path = os.path.abspath(log_file_path)
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == abs_log_path
        for h in root_logger.handlers)
    if not has_file_handler:
        log_file_handler = logging.FileHandler(log_file_path, encoding='utf-8')
        log_file_handler.setFormatter(log_formatter)
        root_logger.addHandler(log_file_handler)

    # FileHandler subclasses StreamHandler, so exclude it explicitly.
    has_stdout_handler = any(
        isinstance(h, logging.StreamHandler)
        and not isinstance(h, logging.FileHandler)
        and h.stream is sys.stdout
        for h in root_logger.handlers)
    if not has_stdout_handler:
        log_stream_handler = logging.StreamHandler(sys.stdout)
        log_stream_handler.setFormatter(log_formatter)
        root_logger.addHandler(log_stream_handler)

    logging.info('Logging file is %s', log_file_path)
|
|
|
|
def print_args(args):
    """Log every attribute of ``args`` as ``name:value`` at INFO level.

    Args:
        args: Any object supporting ``vars()``, e.g. an
            ``argparse.Namespace``.
    """
    for arg in vars(args):
        # Pass the value as a logging argument instead of pre-formatting it
        # with ``%``: ``':%s' % value`` raises TypeError when ``value`` is
        # itself a tuple (e.g. an ``nargs=2`` option or ``()``).
        logging.info('%s:%s', arg, getattr(args, arg))