File size: 3,130 Bytes
4a0f6a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""CLI helpers shared across experiment scripts."""
from __future__ import annotations

import argparse
import logging

from .io_utils import DEFAULT_MANIFEST_FILE, DEFAULT_OUTPUT_ROOT


def add_base_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the runtime flags shared by every experiment script.

    Adds ``--run_name`` (mandatory), output/manifest locations, the smoke-test
    switch, the random seed, the torch device and the console log level.

    Returns the same ``parser`` object so calls can be chained.
    """
    add = parser.add_argument  # bound once; every flag goes onto the same parser
    add("--run_name", required=True,
        help="Name of this run (used for output directory).")
    add("--output_root", default=str(DEFAULT_OUTPUT_ROOT),
        help="Root directory for all run outputs.")
    add("--manifest_file", default=str(DEFAULT_MANIFEST_FILE),
        help="Path to dataset_manifests.json.")
    add("--smoke", action="store_true",
        help="Run in smoke-test mode (small sample).")
    add("--seed", type=int, default=42,
        help="Random seed.")
    add("--device", default="auto",
        help="Device for torch models (auto/cpu/cuda).")
    add("--log_level", default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Console log level.")
    return parser


def add_train_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register training hyper-parameters commonly overridden across experiments.

    Every option defaults to ``None`` so an unset flag can be told apart from
    an explicit value. Callers presumably merge these with per-experiment
    defaults via ``resolve_arg``.

    Returns the same ``parser`` object so calls can be chained.
    """
    parser.add_argument("--model_id", default=None, help="HuggingFace model identifier (e.g. hfl/chinese-bert-wwm-ext).")
    parser.add_argument("--epochs", type=int, default=None, help="Number of training epochs.")
    parser.add_argument("--batch_size", type=int, default=None, help="Training batch size.")
    parser.add_argument("--eval_batch_size", type=int, default=None, help="Evaluation batch size.")
    parser.add_argument("--max_len", type=int, default=None, help="Maximum sequence length (tokenizer).")
    parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=None, help="AdamW weight decay.")
    parser.add_argument("--warmup_ratio", type=float, default=None, help="Warmup ratio for linear scheduler.")
    parser.add_argument("--grad_acc", type=int, default=None, help="Gradient accumulation steps.")
    parser.add_argument("--early_stopping_patience", type=int, default=None, help="Early stopping patience (epochs).")
    # --use_amp and --no_amp are opposite switches. Passing both is
    # contradictory, so argparse rejects the combination. Both keep
    # default=None (not False) so "neither given" stays distinguishable,
    # and both keep their own dest for existing callers.
    amp_group = parser.add_mutually_exclusive_group()
    amp_group.add_argument("--use_amp", action="store_true", default=None, help="Enable automatic mixed precision (AMP).")
    amp_group.add_argument("--no_amp", action="store_true", default=None, help="Disable automatic mixed precision (AMP).")
    return parser


def setup_logging(level: str | int = logging.INFO) -> None:
    """Configure root logger with a consistent format."""
    if isinstance(level, str):
        level = getattr(logging, level.upper(), logging.INFO)
    logging.basicConfig(
        level=level,
        format="%(asctime)s | %(levelname)-8s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )


def resolve_arg(val, fallback):
    """Prefer the explicitly supplied CLI value, otherwise use ``fallback``.

    Only ``None`` counts as "not provided"; falsy values such as ``0``,
    ``False`` or ``""`` are returned unchanged.
    """
    if val is None:
        return fallback
    return val