File size: 6,993 Bytes
4a0f6a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""E01: BERT baseline training and evaluation.

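This is a standalone entry-point script (beyond third-party libraries it relies
only on the repository's shared ``enhanced_replica`` package). It loads DS01,
trains a BERT classifier with a custom MLP head (champion-aligned settings),
selects a decision threshold on the dev split, evaluates every split, and
writes predictions and metrics to the run directory.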
"""
from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path

import numpy as np
import pandas as pd

# Locate the repository root: starting from this file, climb parent
# directories until one contains a ``src/`` entry. If no such directory is
# found the climb stops at the filesystem root.
REPO_ROOT = Path(__file__).resolve()
while not (REPO_ROOT / "src").exists():
    if REPO_ROOT.parent == REPO_ROOT:
        break
    REPO_ROOT = REPO_ROOT.parent

# Prepend the root and then ``<root>/src`` to the import path (so ``src`` ends
# up first), skipping directories that do not exist or are already present.
for _search_dir in (REPO_ROOT, REPO_ROOT / "src"):
    if not _search_dir.exists():
        continue
    _search_dir_str = str(_search_dir)
    if _search_dir_str in sys.path:
        continue
    sys.path.insert(0, _search_dir_str)

from enhanced_replica.cli_args import add_base_args, add_train_args, resolve_arg, setup_logging
from enhanced_replica.data_utils import get_ds_meta, load_dataset_manifest, load_dataset_splits, SPLITS
from enhanced_replica.io_utils import (
    create_run_context,
    prepare_primary_ds_for_train,
    save_pred_files,
    write_run_manifest,
    write_run_report,
    write_yaml_minimal,
)
from enhanced_replica.metrics import best_threshold_by_f1, binary_metrics
from enhanced_replica.model_utils import predict_with_model_payload, save_model_payload, train_transformer_classifier

DEFAULT_MODEL_ID = "hfl/chinese-bert-wwm-ext"
DEFAULT_EPOCHS = 20
DEFAULT_BATCH_SIZE = 8
DEFAULT_EVAL_BATCH_SIZE = 16
DEFAULT_MAX_LEN = 512
DEFAULT_LR = 2e-5
DEFAULT_WEIGHT_DECAY = 0.03
DEFAULT_WARMUP_RATIO = 0.1
DEFAULT_EARLY_STOPPING_PATIENCE = 3


def run_e01(args: argparse.Namespace) -> dict:
    """Train and evaluate the E01 BERT baseline on DS01.

    Steps: resolve hyper-parameters (a CLI value, when given via
    ``resolve_arg``, overrides the module ``DEFAULT_*`` constant), create a
    run context, load the DS01 splits, fine-tune a transformer with a custom
    MLP head, score every split, choose the decision threshold that maximises
    F1 on the dev split, then write predictions, per-split metric CSVs, the
    run config YAML, the run manifest and the run report.

    Args:
        args: Parsed CLI namespace. Attributes read: ``log_level``,
            ``model_id``, ``epochs``, ``batch_size``, ``eval_batch_size``,
            ``max_len``, ``lr``, ``weight_decay``, ``warmup_ratio``,
            ``grad_acc``, ``early_stopping_patience``, ``no_amp``,
            ``output_root``, ``run_name``, ``smoke``, ``manifest_file``,
            ``seed``, ``merge_dev`` and ``device``.

    Returns:
        A dict with ``dev_f1``, ``test_f1`` and ``threshold``; the same dict
        is stored as the payload of the run manifest and report.
    """
    setup_logging(args.log_level)
    logger = logging.getLogger("E01")

    # 1. Resolve hyper-parameters (CLI overrides script defaults)
    model_id = resolve_arg(args.model_id, DEFAULT_MODEL_ID)
    epochs = resolve_arg(args.epochs, DEFAULT_EPOCHS)
    batch_size = resolve_arg(args.batch_size, DEFAULT_BATCH_SIZE)
    eval_batch_size = resolve_arg(args.eval_batch_size, DEFAULT_EVAL_BATCH_SIZE)
    max_len = resolve_arg(args.max_len, DEFAULT_MAX_LEN)
    lr = resolve_arg(args.lr, DEFAULT_LR)
    weight_decay = resolve_arg(args.weight_decay, DEFAULT_WEIGHT_DECAY)
    warmup_ratio = resolve_arg(args.warmup_ratio, DEFAULT_WARMUP_RATIO)
    grad_acc = resolve_arg(args.grad_acc, 1)
    early_stopping_patience = resolve_arg(args.early_stopping_patience, DEFAULT_EARLY_STOPPING_PATIENCE)
    # AMP is enabled unless --no_amp is given. The previous expression
    # ``False if no_amp else (use_amp if use_amp else True)`` always reduced
    # to this: ``args.use_amp`` could never turn AMP off.
    use_amp = not args.no_amp

    # Fixed (non-CLI) model/stability settings. They are kept as named locals
    # so that the exact values passed to training are also recorded in the
    # run config written below.
    custom_head_dropout = 0.5
    custom_head_intermediate = 512
    loss_divergence_patience = 2
    max_restarts = 1

    # 2. Setup run context
    ctx = create_run_context(eid="E01", output_root=Path(args.output_root), run_name=args.run_name)
    logger.info(f"E01 BERT baseline start | run_name={ctx.run_name} | smoke={args.smoke}")
    logger.info(f"Hyper-params: model={model_id}, epochs={epochs}, max_len={max_len}, lr={lr}, bs={batch_size}")

    # 3. Load dataset manifest and locate DS01
    manifest = load_dataset_manifest(Path(args.manifest_file))
    ds_meta = get_ds_meta(manifest, "DS01")
    logger.info(f"Dataset: {ds_meta['dataset_id']} | dir={ds_meta['dataset_dir']}")

    # 4. Load splits (optionally merging train+dev and re-splitting)
    splits = prepare_primary_ds_for_train(ds_meta, smoke=args.smoke, seed=args.seed, merge_dev=args.merge_dev)
    for sp in SPLITS:
        logger.info(f"  {sp}: {len(splits[sp])} rows")

    # 5. Train transformer with champion-aligned settings.
    # The second return value (dev scores from training) is not used: every
    # split is re-scored from the persisted payload in step 7.
    model_payload, _dev_scores = train_transformer_classifier(
        train_df=splits["train"],
        dev_df=splits["dev"],
        model_id=model_id,
        model_output_dir=ctx.run_dir / "model",
        seed=args.seed,
        device=args.device,
        epochs=epochs,
        batch_size=batch_size,
        eval_batch_size=eval_batch_size,
        learning_rate=lr,
        max_len=max_len,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        gradient_accumulation_steps=grad_acc,
        use_custom_head=True,
        custom_head_dropout=custom_head_dropout,
        custom_head_intermediate=custom_head_intermediate,
        save_best=True,
        early_stopping_patience=early_stopping_patience,
        use_amp=use_amp,
        loss_divergence_patience=loss_divergence_patience,
        max_restarts=max_restarts,
    )
    logger.info("Training complete. Best model saved to hf_model_best/.")

    # 6. Persist model payload
    model_payload_path = ctx.run_dir / "model_payload.pkl"
    save_model_payload(model_payload_path, model_payload)

    # 7. Predict on all splits (copies, so the loaded splits stay untouched)
    pred_splits = {}
    for sp in SPLITS:
        scores = predict_with_model_payload(model_payload, splits[sp], device=args.device)
        df = splits[sp].copy()
        df["score"] = scores
        pred_splits[sp] = df

    # 8. Determine threshold on dev only, so test metrics stay held-out
    threshold = best_threshold_by_f1(
        pred_splits["dev"]["label"].to_numpy(),
        pred_splits["dev"]["score"].to_numpy(),
    )
    logger.info(f"Best dev threshold by F1: {threshold:.4f}")

    # 9. Apply threshold and compute metrics
    metrics = {}
    for sp in SPLITS:
        split_df = pred_splits[sp]  # same object: the column below is written in place
        split_df["pred"] = (split_df["score"] >= threshold).astype(int)
        m = binary_metrics(split_df["label"].to_numpy(), split_df["score"].to_numpy(), threshold)
        metrics[sp] = m
        logger.info(f"{sp} metrics: accuracy={m['accuracy']:.4f}, f1={m['f1']:.4f}")

    # 10. Save predictions and metrics
    save_pred_files(ctx, pred_splits)
    for sp in SPLITS:
        pd.DataFrame([metrics[sp]]).to_csv(ctx.run_dir / f"bert_{sp}_metrics.csv", index=False)

    # 11. Write run manifest and report. The config records every setting
    # that affects the result, including merge_dev (which changes the data)
    # and the fixed head/stability settings above.
    config = {
        "model_id": model_id,
        "dataset_id": ds_meta["dataset_id"],
        "epochs": epochs,
        "batch_size": batch_size,
        "eval_batch_size": eval_batch_size,
        "max_len": max_len,
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "warmup_ratio": warmup_ratio,
        "grad_acc": grad_acc,
        "early_stopping_patience": early_stopping_patience,
        "use_amp": use_amp,
        "custom_head_dropout": custom_head_dropout,
        "custom_head_intermediate": custom_head_intermediate,
        "loss_divergence_patience": loss_divergence_patience,
        "max_restarts": max_restarts,
        "threshold": float(threshold),
        "seed": args.seed,
        "smoke": args.smoke,
        "merge_dev": args.merge_dev,
        "device": args.device,
    }
    write_yaml_minimal(ctx.config_file, config)

    result = {
        "dev_f1": metrics["dev"]["f1"],
        "test_f1": metrics["test"]["f1"],
        "threshold": float(threshold),
    }
    write_run_manifest(ctx, status="success", payload=result)
    write_run_report(ctx, status="success", config=config, payload=result)
    logger.info("E01 complete.")
    return result


def main() -> int:
    """CLI entry point: parse arguments, run E01, and return an exit status.

    Returns 0 when the run finishes. If the run raises, the exception is
    logged with its traceback on the ``E01`` logger and then re-raised
    unchanged.
    """
    arg_parser = add_train_args(
        add_base_args(argparse.ArgumentParser(description="E01 BERT baseline"))
    )
    arg_parser.add_argument("--merge_dev", action="store_true", help="Merge original train+dev and re-split before training.")
    cli_args = arg_parser.parse_args()

    try:
        run_e01(cli_args)
    except Exception as exc:
        # Logger.exception == error(..., exc_info=True)
        logging.getLogger("E01").exception(f"ERROR: {exc}")
        raise
    return 0


if __name__ == "__main__":
    # sys.exit raises SystemExit carrying main()'s return code.
    sys.exit(main())