#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""CLI helpers shared across experiment scripts.

Provides builders for the common runtime and training argument sets, a
root-logger bootstrapper, and a helper that lets explicitly passed CLI
values override fallback (e.g. config) values.
"""

from __future__ import annotations

import argparse
import logging
from typing import TypeVar

from .io_utils import DEFAULT_MANIFEST_FILE, DEFAULT_OUTPUT_ROOT

_T = TypeVar("_T")


def add_base_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the base runtime arguments applicable to every experiment.

    Args:
        parser: Parser to extend in place.

    Returns:
        The same ``parser`` instance, so calls can be chained.
    """
    parser.add_argument("--run_name", required=True,
                        help="Name of this run (used for output directory).")
    parser.add_argument("--output_root", default=str(DEFAULT_OUTPUT_ROOT),
                        help="Root directory for all run outputs.")
    parser.add_argument("--manifest_file", default=str(DEFAULT_MANIFEST_FILE),
                        help="Path to dataset_manifests.json.")
    parser.add_argument("--smoke", action="store_true",
                        help="Run in smoke-test mode (small sample).")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--device", default="auto",
                        help="Device for torch models (auto/cpu/cuda).")
    parser.add_argument("--log_level", default="INFO",
                        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                        help="Console log level.")
    return parser


def add_train_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register training hyper-parameters commonly overridden across experiments.

    Every option defaults to ``None`` so callers can tell "not given on the
    command line" apart from an explicit value (see ``resolve_arg``).

    Args:
        parser: Parser to extend in place.

    Returns:
        The same ``parser`` instance, so calls can be chained.
    """
    parser.add_argument("--model_id", default=None,
                        help="HuggingFace model identifier (e.g. hfl/chinese-bert-wwm-ext).")
    parser.add_argument("--epochs", type=int, default=None,
                        help="Number of training epochs.")
    parser.add_argument("--batch_size", type=int, default=None,
                        help="Training batch size.")
    parser.add_argument("--eval_batch_size", type=int, default=None,
                        help="Evaluation batch size.")
    parser.add_argument("--max_len", type=int, default=None,
                        help="Maximum sequence length (tokenizer).")
    parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=None,
                        help="AdamW weight decay.")
    parser.add_argument("--warmup_ratio", type=float, default=None,
                        help="Warmup ratio for linear scheduler.")
    parser.add_argument("--grad_acc", type=int, default=None,
                        help="Gradient accumulation steps.")
    parser.add_argument("--early_stopping_patience", type=int, default=None,
                        help="Early stopping patience (epochs).")

    # --use_amp and --no_amp contradict each other, so argparse is told to
    # reject passing both. Each keeps default=None so that "neither flag
    # given" remains distinguishable from an explicit choice.
    amp_group = parser.add_mutually_exclusive_group()
    amp_group.add_argument("--use_amp", action="store_true", default=None,
                           help="Enable automatic mixed precision (AMP).")
    amp_group.add_argument("--no_amp", action="store_true", default=None,
                           help="Disable automatic mixed precision (AMP).")
    return parser


def setup_logging(level: str | int = logging.INFO) -> None:
    """Configure the root logger with a consistent format.

    Args:
        level: A numeric logging level, or a level name such as ``"DEBUG"``
            (case-insensitive, surrounding whitespace ignored). Unknown names
            fall back to ``logging.INFO``.

    Note:
        This delegates to ``logging.basicConfig``, which does nothing if the
        root logger already has handlers configured.
    """
    if isinstance(level, str):
        # getLevelName maps a registered level name to its int and returns a
        # "Level <name>" string otherwise. Unlike getattr(logging, name), it
        # cannot pick up unrelated module attributes such as BASIC_FORMAT.
        resolved = logging.getLevelName(level.strip().upper())
        level = resolved if isinstance(resolved, int) else logging.INFO
    logging.basicConfig(
        level=level,
        format="%(asctime)s | %(levelname)-8s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )


def resolve_arg(val: _T | None, fallback: _T) -> _T:
    """Return the CLI value if it was explicitly provided, else ``fallback``.

    Only ``None`` counts as "not provided"; falsy values such as ``0``,
    ``0.0``, ``""`` or ``False`` are returned as-is.
    """
    return val if val is not None else fallback