"""Shared helpers: logging setup, JSON I/O, seeding, batching and timing."""
from __future__ import annotations

import json
import logging
import math
import os
import random
import sys
import time
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List

from . import config as CFG

# Cached logger instance; populated on the first call to get_logger().
_LOGGER = None


def get_logger() -> logging.Logger:
    """Return the shared ``'bpe'`` logger, configuring it on first use.

    On the first call this creates ``CFG.OUTPUT_DIR`` if needed, sets the
    logger level to INFO and attaches two handlers: a UTF-8 file handler
    writing to ``train.log`` inside ``CFG.OUTPUT_DIR`` and a stream handler
    writing to ``sys.stdout``. Later calls return the cached instance.

    Returns:
        The configured ``logging.Logger``.
    """
    global _LOGGER
    if _LOGGER is not None:
        return _LOGGER

    os.makedirs(CFG.OUTPUT_DIR, exist_ok=True)

    logger = logging.getLogger('bpe')
    logger.setLevel(logging.INFO)

    # logging.getLogger returns a process-wide singleton per name. If this
    # module is reloaded (or imported under a second name) the module cache
    # above is reset while the logger keeps its handlers; attaching again
    # would make every record appear twice.
    if not logger.handlers:
        log_path = os.path.join(CFG.OUTPUT_DIR, 'train.log')
        formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')

        file_handler = logging.FileHandler(log_path, encoding='utf-8')
        file_handler.setFormatter(formatter)

        stream_handler = logging.StreamHandler(sys.stdout)
        stream_handler.setFormatter(formatter)

        logger.addHandler(file_handler)
        logger.addHandler(stream_handler)

    _LOGGER = logger
    return logger


def save_json(path: str, obj: Any) -> None:
    """Serialize *obj* as indented UTF-8 JSON to *path*.

    Non-ASCII characters are written as-is (``ensure_ascii=False``) and the
    output is indented by two spaces. An existing file is overwritten.
    """
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)


def load_json(path: str) -> Any:
    """Read the UTF-8 JSON file at *path* and return the decoded object."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def set_seed(seed: int) -> None:
    """Seed the standard-library ``random`` module with *seed*.

    Only ``random`` is seeded here; no other random number generator is
    touched by this function.
    """
    random.seed(seed)


def chunks(iterable: Iterable, size: int) -> Iterable[List[Any]]:
    """Yield successive lists of up to *size* items taken from *iterable*.

    Every yielded list has exactly *size* items except possibly the last,
    which holds the remainder. An empty *iterable* yields nothing.

    Args:
        iterable: Any iterable; it is consumed lazily.
        size: Maximum number of items per chunk; must be at least 1.

    Raises:
        ValueError: If *size* is smaller than 1. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # Without this check a non-positive size would silently degrade to
    # one-element chunks, since the length test below would always pass.
    if size < 1:
        raise ValueError(f"size must be at least 1, got {size!r}")

    bucket = []
    for item in iterable:
        bucket.append(item)
        if len(bucket) >= size:
            yield bucket
            # Rebind rather than clear so the list already handed to the
            # consumer is not mutated.
            bucket = []
    if bucket:
        yield bucket


class Timer:
    """Minimal stopwatch that starts when the instance is created."""

    def __init__(self):
        # Wall-clock timestamp (seconds since the epoch) at construction.
        # NOTE(review): time.time() can jump if the system clock is
        # adjusted; kept because `start` is a public attribute.
        self.start = time.time()

    def elapsed(self):
        """Return the seconds elapsed since ``self.start`` as a float."""
        return time.time() - self.start