#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""E01: BERT baseline training and evaluation.
This script is fully self-contained. It loads DS01, trains a BERT classifier
with a custom MLP head (champion-aligned), evaluates on dev/test, and writes
predictions + metrics to the run directory.
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
import numpy as np
import pandas as pd
# Allow imports from source root package
# Walk upward from this file until we find a directory that contains ``src``
# (or hit the filesystem root); treat that directory as the repository root.
REPO_ROOT = Path(__file__).resolve()
while not (REPO_ROOT / "src").exists() and REPO_ROOT.parent != REPO_ROOT:
    REPO_ROOT = REPO_ROOT.parent
# Make both the repo root and its src/ directory importable, once each.
for _path in (REPO_ROOT, REPO_ROOT / "src"):
    _entry = str(_path)
    if _path.exists() and _entry not in sys.path:
        sys.path.insert(0, _entry)
from enhanced_replica.cli_args import add_base_args, add_train_args, resolve_arg, setup_logging
from enhanced_replica.data_utils import get_ds_meta, load_dataset_manifest, load_dataset_splits, SPLITS
from enhanced_replica.io_utils import (
create_run_context,
prepare_primary_ds_for_train,
save_pred_files,
write_run_manifest,
write_run_report,
write_yaml_minimal,
)
from enhanced_replica.metrics import best_threshold_by_f1, binary_metrics
from enhanced_replica.model_utils import predict_with_model_payload, save_model_payload, train_transformer_classifier
# --- Champion-aligned default hyper-parameters (CLI flags override these) ---
DEFAULT_MODEL_ID = "hfl/chinese-bert-wwm-ext"  # Chinese BERT, whole-word masking
DEFAULT_EPOCHS = 20
DEFAULT_BATCH_SIZE = 8  # train batch size (before gradient accumulation)
DEFAULT_EVAL_BATCH_SIZE = 16
DEFAULT_MAX_LEN = 512  # tokenizer truncation length
DEFAULT_LR = 2e-5
DEFAULT_WEIGHT_DECAY = 0.03
DEFAULT_WARMUP_RATIO = 0.1  # fraction of total steps spent in LR warmup
DEFAULT_EARLY_STOPPING_PATIENCE = 3  # epochs without dev improvement before stopping
def run_e01(args: argparse.Namespace) -> dict:
    """Train and evaluate the E01 BERT baseline on DS01.

    Pipeline: resolve hyper-parameters (CLI overrides script defaults),
    create a run directory, load the DS01 splits, train a transformer
    classifier with the champion-aligned custom MLP head, pick the decision
    threshold that maximizes dev F1, then write predictions, per-split
    metrics, config, run manifest, and run report to the run directory.

    Args:
        args: Parsed CLI namespace produced by ``main()`` (base + train args
            plus ``--merge_dev``).

    Returns:
        dict with keys ``dev_f1``, ``test_f1``, and ``threshold``.
    """
    setup_logging(args.log_level)
    logger = logging.getLogger("E01")

    # 1. Resolve hyper-parameters (CLI value wins over the script default).
    model_id = resolve_arg(args.model_id, DEFAULT_MODEL_ID)
    epochs = resolve_arg(args.epochs, DEFAULT_EPOCHS)
    batch_size = resolve_arg(args.batch_size, DEFAULT_BATCH_SIZE)
    eval_batch_size = resolve_arg(args.eval_batch_size, DEFAULT_EVAL_BATCH_SIZE)
    max_len = resolve_arg(args.max_len, DEFAULT_MAX_LEN)
    lr = resolve_arg(args.lr, DEFAULT_LR)
    weight_decay = resolve_arg(args.weight_decay, DEFAULT_WEIGHT_DECAY)
    warmup_ratio = resolve_arg(args.warmup_ratio, DEFAULT_WARMUP_RATIO)
    grad_acc = resolve_arg(args.grad_acc, 1)
    early_stopping_patience = resolve_arg(
        args.early_stopping_patience, DEFAULT_EARLY_STOPPING_PATIENCE
    )
    # AMP defaults to on; --no_amp takes precedence over --use_amp.  The
    # original nested-conditional expression reduced to exactly this: the
    # non-no_amp branch always produced a truthy value.
    use_amp = not args.no_amp

    # 2. Set up the run context (run directory, run name).
    ctx = create_run_context(eid="E01", output_root=Path(args.output_root), run_name=args.run_name)
    logger.info(f"E01 BERT baseline start | run_name={ctx.run_name} | smoke={args.smoke}")
    # BUGFIX: this log line previously printed a hard-coded garbage literal
    # ("max_len=12,604") instead of the resolved max_len value.
    logger.info(f"Hyper-params: model={model_id}, epochs={epochs}, max_len={max_len}, lr={lr}, bs={batch_size}")

    # 3. Load the dataset manifest and locate DS01.
    manifest = load_dataset_manifest(Path(args.manifest_file))
    ds_meta = get_ds_meta(manifest, "DS01")
    logger.info(f"Dataset: {ds_meta['dataset_id']} | dir={ds_meta['dataset_dir']}")

    # 4. Load splits (sub-sampled under --smoke, re-split under --merge_dev).
    splits = prepare_primary_ds_for_train(ds_meta, smoke=args.smoke, seed=args.seed, merge_dev=args.merge_dev)
    for sp in SPLITS:
        logger.info(f" {sp}: {len(splits[sp])} rows")

    # 5. Train the transformer with champion-aligned settings.
    model_payload, dev_scores = train_transformer_classifier(
        train_df=splits["train"],
        dev_df=splits["dev"],
        model_id=model_id,
        model_output_dir=ctx.run_dir / "model",
        seed=args.seed,
        device=args.device,
        epochs=epochs,
        batch_size=batch_size,
        eval_batch_size=eval_batch_size,
        learning_rate=lr,
        max_len=max_len,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        gradient_accumulation_steps=grad_acc,
        use_custom_head=True,  # champion-aligned MLP head
        custom_head_dropout=0.5,
        custom_head_intermediate=512,
        save_best=True,
        early_stopping_patience=early_stopping_patience,
        use_amp=use_amp,
        loss_divergence_patience=2,  # abort/restart on diverging loss
        max_restarts=1,
    )
    logger.info("Training complete. Best model saved to hf_model_best/.")

    # 6. Persist the model payload for later reuse.
    model_payload_path = ctx.run_dir / "model_payload.pkl"
    save_model_payload(model_payload_path, model_payload)

    # 7. Score every split with the trained model.
    pred_splits = {}
    for sp in SPLITS:
        scores = predict_with_model_payload(model_payload, splits[sp], device=args.device)
        df = splits[sp].copy()
        df["score"] = scores
        pred_splits[sp] = df

    # 8. Pick the decision threshold that maximizes F1 on dev only
    #    (test is never used for threshold selection).
    threshold = best_threshold_by_f1(
        pred_splits["dev"]["label"].to_numpy(),
        pred_splits["dev"]["score"].to_numpy(),
    )
    logger.info(f"Best dev threshold by F1: {threshold:.4f}")

    # 9. Apply the threshold and compute per-split metrics.
    metrics = {}
    for sp in SPLITS:
        pred_splits[sp]["pred"] = (pred_splits[sp]["score"] >= threshold).astype(int)
        m = binary_metrics(pred_splits[sp]["label"].to_numpy(), pred_splits[sp]["score"].to_numpy(), threshold)
        metrics[sp] = m
        logger.info(f"{sp} metrics: accuracy={m['accuracy']:.4f}, f1={m['f1']:.4f}")

    # 10. Save predictions and per-split metrics CSVs.
    save_pred_files(ctx, pred_splits)
    for sp in SPLITS:
        pd.DataFrame([metrics[sp]]).to_csv(ctx.run_dir / f"bert_{sp}_metrics.csv", index=False)

    # 11. Write config, run manifest, and human-readable report.
    config = {
        "model_id": model_id,
        "epochs": epochs,
        "batch_size": batch_size,
        "eval_batch_size": eval_batch_size,
        "max_len": max_len,
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "warmup_ratio": warmup_ratio,
        "grad_acc": grad_acc,
        "early_stopping_patience": early_stopping_patience,
        "use_amp": use_amp,
        "threshold": float(threshold),
        "seed": args.seed,
        "smoke": args.smoke,
        "device": args.device,
    }
    write_yaml_minimal(ctx.config_file, config)
    result = {
        "dev_f1": metrics["dev"]["f1"],
        "test_f1": metrics["test"]["f1"],
        "threshold": float(threshold),
    }
    write_run_manifest(ctx, status="success", payload=result)
    write_run_report(ctx, status="success", config=config, payload=result)
    logger.info("E01 complete.")
    return result
def main() -> int:
    """CLI entry point: parse arguments, run E01, return 0 on success."""
    parser = argparse.ArgumentParser(description="E01 BERT baseline")
    for augment in (add_base_args, add_train_args):
        parser = augment(parser)
    parser.add_argument(
        "--merge_dev",
        action="store_true",
        help="Merge original train+dev and re-split before training.",
    )
    args = parser.parse_args()
    try:
        run_e01(args)
    except Exception as exc:
        # Log the full traceback, then re-raise so the process exits
        # with a non-zero status.
        logging.getLogger("E01").error(f"ERROR: {exc}", exc_info=True)
        raise
    return 0


if __name__ == "__main__":
    raise SystemExit(main())