"""
Common utilities: logging, jsonl I/O, atomic write, VRAM check.
"""
import gc
import json
import logging
import os
import sys
from datetime import datetime
from pathlib import Path


def setup_logger(name: str, log_file: Path = None, level=logging.INFO) -> logging.Logger:
    """
    Configure a logger that writes to stdout and, optionally, to a file.

    Any handlers already attached to the logger are detached and closed first,
    so calling this repeatedly (e.g. on re-import) neither duplicates log lines
    nor leaks open log-file handles.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_file: Optional path of a log file opened in append mode; parent
            directories are created as needed.
        level: Logging level applied to the logger.

    Returns:
        The configured logger (propagation to ancestor loggers is disabled).
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Detach *and close* old handlers: merely dropping the list would leave
    # previously opened FileHandler streams open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    logger.propagate = False

    fmt = logging.Formatter(
        "[%(asctime)s] %(levelname)s %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(fmt)
    logger.addHandler(stream_handler)

    if log_file is not None:
        log_file = Path(log_file)
        log_file.parent.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8")
        file_handler.setFormatter(fmt)
        logger.addHandler(file_handler)

    return logger


def _atomic_write(path, mode, write_payload, encoding=None):
    """
    Write to ``<path>.tmp`` via ``write_payload(file)`` and rename it over ``path``.

    If writing or renaming fails, the temporary file is removed and the
    original exception is re-raised, so no stale ``.tmp`` file is left behind
    and an existing ``path`` is left untouched.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        with open(tmp, mode, encoding=encoding) as f:
            write_payload(f)
        os.replace(tmp, path)
    except BaseException:
        # Best-effort cleanup; never let it mask the original error.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise


def atomic_write_bytes(path: Path, data: bytes) -> None:
    """Atomic file write: write ``data`` to a ``.tmp`` sibling, then rename it over ``path``."""
    _atomic_write(path, "wb", lambda f: f.write(data))


def write_jsonl(records, path: Path) -> None:
    """
    Write ``records`` to ``path`` as UTF-8 JSON Lines, one object per line.

    The file is written to a ``.tmp`` sibling and renamed into place, so
    ``path`` is only replaced once every record has been serialized.

    Raises:
        TypeError / ValueError: if a record is not JSON-serializable (the
            temporary file is removed and ``path`` is left unchanged).
    """

    def dump_lines(f):
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

    _atomic_write(path, "w", dump_lines, encoding="utf-8")


def read_jsonl(path: Path) -> list:
    """
    Read a UTF-8 JSON Lines file and return its records as a list.

    Blank (whitespace-only) lines are skipped; a malformed line raises
    ``json.JSONDecodeError``.
    """
    path = Path(path)
    with open(path, "r", encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        return [json.loads(text) for text in stripped_lines if text]


def append_jsonl(record, path: Path) -> None:
    """Append a single record to ``path`` as one JSON line, creating parent directories as needed."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")


def write_json(obj, path: Path, indent=2) -> None:
    """
    Serialize ``obj`` to ``path`` as UTF-8 JSON via a ``.tmp`` file and rename.

    Args:
        obj: JSON-serializable object.
        path: Destination file; parent directories are created as needed.
        indent: Indentation passed to ``json.dump``.
    """
    _atomic_write(
        path,
        "w",
        lambda f: json.dump(obj, f, ensure_ascii=False, indent=indent),
        encoding="utf-8",
    )


def read_json(path: Path):
    """Load and return the JSON document stored at ``path`` (UTF-8)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def get_vram_mb() -> float:
    """
    Return ``torch.cuda.memory_allocated()`` converted to MiB.

    Returns 0.0 when torch cannot be imported, CUDA is unavailable, or the
    query fails for any reason.
    """
    # Deliberately broad: torch is optional and this is a best-effort probe.
    try:
        import torch

        if torch.cuda.is_available():
            return torch.cuda.memory_allocated() / 1024**2
    except Exception:
        pass
    return 0.0


def cleanup_memory() -> None:
    """Call after each heavy step: run the garbage collector and empty the CUDA cache if possible."""
    gc.collect()
    # Deliberately broad: torch is optional and cache cleanup is best-effort.
    try:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:
        pass


def ts() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    return datetime.now().strftime("%Y-%m-%d %H:%M:%S")


def compute_completed_ids(path: Path) -> set:
    """
    For resume: read existing jsonl and return set of completed `idx` values.

    Returns an empty set when ``path`` does not exist. Lines that are blank,
    are not valid JSON, are not JSON objects, lack an ``idx`` key, or carry an
    unhashable ``idx`` are skipped rather than raising, so a partially written
    file can still be resumed from.
    """
    path = Path(path)
    if not path.exists():
        return set()

    done = set()
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except ValueError:
                # Truncated or corrupt line (json.JSONDecodeError is a ValueError).
                continue
            if not isinstance(obj, dict) or "idx" not in obj:
                continue
            try:
                done.add(obj["idx"])
            except TypeError:
                # Unhashable idx (e.g. a list) cannot be tracked; skip it.
                continue
    return done