LUCIFerace's picture
Add files using upload-large-folder tool
4a0f6a5 verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""CLI helpers shared across experiment scripts."""
from __future__ import annotations
import argparse
import logging
from .io_utils import DEFAULT_MANIFEST_FILE, DEFAULT_OUTPUT_ROOT
def add_base_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the runtime options shared by every experiment script.

    The options (run name, output/manifest locations, smoke mode, seed,
    device and console log level) are added to ``parser`` in place, and the
    same parser object is handed back so calls can be chained.

    Args:
        parser: Parser to extend.

    Returns:
        The ``parser`` argument itself.
    """
    add = parser.add_argument
    # Identity of the run and the locations of its outputs / inputs.
    add("--run_name", required=True,
        help="Name of this run (used for output directory).")
    add("--output_root", default=str(DEFAULT_OUTPUT_ROOT),
        help="Root directory for all run outputs.")
    add("--manifest_file", default=str(DEFAULT_MANIFEST_FILE),
        help="Path to dataset_manifests.json.")
    # Execution controls.
    add("--smoke", action="store_true",
        help="Run in smoke-test mode (small sample).")
    add("--seed", type=int, default=42, help="Random seed.")
    add("--device", default="auto",
        help="Device for torch models (auto/cpu/cuda).")
    add("--log_level", default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Console log level.")
    return parser
def add_train_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register training hyper-parameters commonly overridden across experiments.

    Every option defaults to ``None``, so a caller can distinguish "not given
    on the command line" from an explicit value and fall back to its own
    per-experiment default (e.g. via ``resolve_arg``).

    ``--use_amp`` and ``--no_amp`` are mutually exclusive. Passing both is
    rejected by argparse instead of yielding a contradictory namespace. Each
    of them parses to ``True`` when given and ``None`` otherwise.

    Args:
        parser: Parser to extend.

    Returns:
        The ``parser`` argument itself, to allow chaining.
    """
    parser.add_argument("--model_id", default=None, help="HuggingFace model identifier (e.g. hfl/chinese-bert-wwm-ext).")
    parser.add_argument("--epochs", type=int, default=None, help="Number of training epochs.")
    parser.add_argument("--batch_size", type=int, default=None, help="Training batch size.")
    parser.add_argument("--eval_batch_size", type=int, default=None, help="Evaluation batch size.")
    parser.add_argument("--max_len", type=int, default=None, help="Maximum sequence length (tokenizer).")
    parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=None, help="AdamW weight decay.")
    parser.add_argument("--warmup_ratio", type=float, default=None, help="Warmup ratio for linear scheduler.")
    parser.add_argument("--grad_acc", type=int, default=None, help="Gradient accumulation steps.")
    parser.add_argument("--early_stopping_patience", type=int, default=None, help="Early stopping patience (epochs).")
    # Enabling and disabling AMP at the same time is meaningless. Make argparse
    # reject it while keeping the original dest names and None defaults.
    amp_group = parser.add_mutually_exclusive_group()
    amp_group.add_argument("--use_amp", action="store_true", default=None, help="Enable automatic mixed precision (AMP).")
    amp_group.add_argument("--no_amp", action="store_true", default=None, help="Disable automatic mixed precision (AMP).")
    return parser
def setup_logging(level: str | int = logging.INFO) -> None:
"""Configure root logger with a consistent format."""
if isinstance(level, str):
level = getattr(logging, level.upper(), logging.INFO)
logging.basicConfig(
level=level,
format="%(asctime)s | %(levelname)-8s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
def resolve_arg(val, fallback):
    """Pick the CLI value when one was given, otherwise ``fallback``.

    Only ``None`` counts as "not given". Falsy values such as ``0``,
    ``False`` or ``""`` are returned unchanged.
    """
    if val is None:
        return fallback
    return val