#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""E01: BERT baseline training and evaluation.

This script is fully self-contained. It loads DS01, trains a BERT classifier
with a custom MLP head (champion-aligned), evaluates on dev/test, and writes
predictions + metrics to the run directory.
"""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path

import pandas as pd

# Allow imports from the source root package (repo root and src/).
REPO_ROOT = Path(__file__).resolve()
while REPO_ROOT != REPO_ROOT.parent and not (REPO_ROOT / "src").exists():
    REPO_ROOT = REPO_ROOT.parent
for _candidate in (REPO_ROOT, REPO_ROOT / "src"):
    _candidate_str = str(_candidate)
    if _candidate.exists() and _candidate_str not in sys.path:
        sys.path.insert(0, _candidate_str)

from enhanced_replica.cli_args import add_base_args, add_train_args, resolve_arg, setup_logging
from enhanced_replica.data_utils import SPLITS, get_ds_meta, load_dataset_manifest
from enhanced_replica.io_utils import (
    create_run_context,
    prepare_primary_ds_for_train,
    save_pred_files,
    write_run_manifest,
    write_run_report,
    write_yaml_minimal,
)
from enhanced_replica.metrics import best_threshold_by_f1, binary_metrics
from enhanced_replica.model_utils import (
    predict_with_model_payload,
    save_model_payload,
    train_transformer_classifier,
)

DEFAULT_MODEL_ID = "hfl/chinese-bert-wwm-ext"
DEFAULT_EPOCHS = 20
DEFAULT_BATCH_SIZE = 8
DEFAULT_EVAL_BATCH_SIZE = 16
DEFAULT_MAX_LEN = 512
DEFAULT_LR = 2e-5
DEFAULT_WEIGHT_DECAY = 0.03
DEFAULT_WARMUP_RATIO = 0.1
DEFAULT_EARLY_STOPPING_PATIENCE = 3


def run_e01(args: argparse.Namespace) -> dict:
    setup_logging(args.log_level)
    logger = logging.getLogger("E01")

    # 1. Resolve hyper-parameters (CLI overrides script defaults)
    model_id = resolve_arg(args.model_id, DEFAULT_MODEL_ID)
    epochs = resolve_arg(args.epochs, DEFAULT_EPOCHS)
    batch_size = resolve_arg(args.batch_size, DEFAULT_BATCH_SIZE)
    eval_batch_size = resolve_arg(args.eval_batch_size, DEFAULT_EVAL_BATCH_SIZE)
    max_len = resolve_arg(args.max_len, DEFAULT_MAX_LEN)
    lr = resolve_arg(args.lr, DEFAULT_LR)
    weight_decay = resolve_arg(args.weight_decay, DEFAULT_WEIGHT_DECAY)
    warmup_ratio = resolve_arg(args.warmup_ratio, DEFAULT_WARMUP_RATIO)
    grad_acc = resolve_arg(args.grad_acc, 1)
    early_stopping_patience = resolve_arg(args.early_stopping_patience, DEFAULT_EARLY_STOPPING_PATIENCE)
    # AMP defaults to on and is disabled by --no_amp. The original expression,
    # `args.use_amp if args.use_amp else True`, evaluated to True on every
    # branch, so only the no_amp flag actually matters.
    use_amp = not args.no_amp

    # 2. Setup run context
    ctx = create_run_context(eid="E01", output_root=Path(args.output_root), run_name=args.run_name)
    logger.info(f"E01 BERT baseline start | run_name={ctx.run_name} | smoke={args.smoke}")
    logger.info(f"Hyper-params: model={model_id}, epochs={epochs}, max_len={max_len}, lr={lr}, bs={batch_size}")

    # 3. Load dataset manifest and locate DS01
    manifest = load_dataset_manifest(Path(args.manifest_file))
    ds_meta = get_ds_meta(manifest, "DS01")
    logger.info(f"Dataset: {ds_meta['dataset_id']} | dir={ds_meta['dataset_dir']}")

    # 4. Load splits
    splits = prepare_primary_ds_for_train(ds_meta, smoke=args.smoke, seed=args.seed, merge_dev=args.merge_dev)
    for sp in SPLITS:
        logger.info(f"  {sp}: {len(splits[sp])} rows")
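    # The split DataFrames are handed to the trainer as-is. Their exact schema
    # is defined inside enhanced_replica; this script itself only relies on a
    # binary "label" column (see steps 8-9), and the tokenizer presumably
    # consumes a raw-text column. A hypothetical smoke-sized frame, for
    # illustration only:
    #   pd.DataFrame({"text": ["示例文本一", "示例文本二"], "label": [1, 0]})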
    # 5. Train transformer with champion-aligned settings
    model_payload, dev_scores = train_transformer_classifier(
        train_df=splits["train"],
        dev_df=splits["dev"],
        model_id=model_id,
        model_output_dir=ctx.run_dir / "model",
        seed=args.seed,
        device=args.device,
        epochs=epochs,
        batch_size=batch_size,
        eval_batch_size=eval_batch_size,
        learning_rate=lr,
        max_len=max_len,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        gradient_accumulation_steps=grad_acc,
        use_custom_head=True,
        custom_head_dropout=0.5,
        custom_head_intermediate=512,
        save_best=True,
        early_stopping_patience=early_stopping_patience,
        use_amp=use_amp,
        loss_divergence_patience=2,
        max_restarts=1,
    )
    logger.info("Training complete. Best model saved to hf_model_best/.")

    # 6. Persist model payload
    model_payload_path = ctx.run_dir / "model_payload.pkl"
    save_model_payload(model_payload_path, model_payload)

    # 7. Predict on all splits
    pred_splits = {}
    for sp in SPLITS:
        scores = predict_with_model_payload(model_payload, splits[sp], device=args.device)
        df = splits[sp].copy()
        df["score"] = scores
        pred_splits[sp] = df

    # 8. Determine threshold on dev
    threshold = best_threshold_by_f1(
        pred_splits["dev"]["label"].to_numpy(),
        pred_splits["dev"]["score"].to_numpy(),
    )
    logger.info(f"Best dev threshold by F1: {threshold:.4f}")

    # 9. Apply threshold and compute metrics
    metrics = {}
    for sp in SPLITS:
        pred_splits[sp]["pred"] = (pred_splits[sp]["score"] >= threshold).astype(int)
        m = binary_metrics(pred_splits[sp]["label"].to_numpy(), pred_splits[sp]["score"].to_numpy(), threshold)
        metrics[sp] = m
        logger.info(f"{sp} metrics: accuracy={m['accuracy']:.4f}, f1={m['f1']:.4f}")

    # 10. Save predictions and metrics
    save_pred_files(ctx, pred_splits)
    for sp in SPLITS:
        pd.DataFrame([metrics[sp]]).to_csv(ctx.run_dir / f"bert_{sp}_metrics.csv", index=False)

    # 11. Write run manifest and report
    config = {
        "model_id": model_id,
        "epochs": epochs,
        "batch_size": batch_size,
        "eval_batch_size": eval_batch_size,
        "max_len": max_len,
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "warmup_ratio": warmup_ratio,
        "grad_acc": grad_acc,
        "early_stopping_patience": early_stopping_patience,
        "use_amp": use_amp,
        "threshold": float(threshold),
        "seed": args.seed,
        "smoke": args.smoke,
        "device": args.device,
    }
    write_yaml_minimal(ctx.config_file, config)

    result = {
        "dev_f1": metrics["dev"]["f1"],
        "test_f1": metrics["test"]["f1"],
        "threshold": float(threshold),
    }
    write_run_manifest(ctx, status="success", payload=result)
    write_run_report(ctx, status="success", config=config, payload=result)
    logger.info("E01 complete.")
    return result


def main() -> int:
    parser = argparse.ArgumentParser(description="E01 BERT baseline")
    parser = add_base_args(parser)
    parser = add_train_args(parser)
    parser.add_argument(
        "--merge_dev",
        action="store_true",
        help="Merge original train+dev and re-split before training.",
    )
    args = parser.parse_args()
    try:
        run_e01(args)
        return 0
    except Exception as e:
        logging.getLogger("E01").error(f"ERROR: {e}", exc_info=True)
        raise


if __name__ == "__main__":
    raise SystemExit(main())
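
# Example invocations (illustrative only: the script filename and every flag
# spelling except --merge_dev come from add_base_args/add_train_args in
# enhanced_replica.cli_args; they are assumed from the argparse dest names used
# above, not verified here):
#   python e01_bert_baseline.py --manifest_file datasets.yaml --output_root runs --smoke
#   python e01_bert_baseline.py --manifest_file datasets.yaml --output_root runs \
#       --epochs 5 --batch_size 16 --merge_dev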