Text Classification
Transformers
Safetensors
Chinese
chinese
ai-text-detection
ensemble
bert
roberta
qwen
lora
research
dataset
Instructions to use LUCIFerace/enhanced-replica-model-pack with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use LUCIFerace/enhanced-replica-model-pack with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("text-classification", model="LUCIFerace/enhanced-replica-model-pack")
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("LUCIFerace/enhanced-replica-model-pack", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 6,993 Bytes
Revision 4a0f6a5
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""E01: BERT baseline training and evaluation.
This script is fully self-contained. It loads DS01, trains a BERT classifier
with a custom MLP head (champion-aligned), evaluates on dev/test, and writes
predictions + metrics to the run directory.
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
import numpy as np
import pandas as pd
# Make the project package importable when this script is run directly.
# Pick the nearest ancestor directory of this file that contains a ``src``
# entry; if there is none, fall back to the filesystem root. Then put that
# directory and its ``src`` sub-directory (whichever exist) at the front of
# ``sys.path``. ``src`` is inserted last, so it ends up searched first.
_this_file = Path(__file__).resolve()
REPO_ROOT = next(
    (ancestor for ancestor in _this_file.parents if (ancestor / "src").exists()),
    Path(_this_file.anchor),
)
for _path_entry in (REPO_ROOT, REPO_ROOT / "src"):
    if not _path_entry.exists():
        continue
    if str(_path_entry) not in sys.path:
        sys.path.insert(0, str(_path_entry))
from enhanced_replica.cli_args import add_base_args, add_train_args, resolve_arg, setup_logging
from enhanced_replica.data_utils import get_ds_meta, load_dataset_manifest, load_dataset_splits, SPLITS
from enhanced_replica.io_utils import (
create_run_context,
prepare_primary_ds_for_train,
save_pred_files,
write_run_manifest,
write_run_report,
write_yaml_minimal,
)
from enhanced_replica.metrics import best_threshold_by_f1, binary_metrics
from enhanced_replica.model_utils import predict_with_model_payload, save_model_payload, train_transformer_classifier
DEFAULT_MODEL_ID = "hfl/chinese-bert-wwm-ext"
DEFAULT_EPOCHS = 20
DEFAULT_BATCH_SIZE = 8
DEFAULT_EVAL_BATCH_SIZE = 16
DEFAULT_MAX_LEN = 512
DEFAULT_LR = 2e-5
DEFAULT_WEIGHT_DECAY = 0.03
DEFAULT_WARMUP_RATIO = 0.1
DEFAULT_EARLY_STOPPING_PATIENCE = 3
def _predict_all_splits(model_payload, splits: dict, device) -> dict:
    """Return a copy of every split in ``SPLITS`` with a model ``score`` column.

    Scores come from ``predict_with_model_payload`` applied to the original
    frame; the frames in *splits* are not modified by this helper.
    """
    pred_splits = {}
    for sp in SPLITS:
        scores = predict_with_model_payload(model_payload, splits[sp], device=device)
        scored = splits[sp].copy()
        scored["score"] = scores
        pred_splits[sp] = scored
    return pred_splits


def _threshold_and_score_splits(pred_splits: dict, threshold: float, logger: logging.Logger) -> dict:
    """Add a 0/1 ``pred`` column (``score >= threshold``) to each split in place
    and return the ``binary_metrics`` result for every split, keyed by name."""
    metrics = {}
    for sp in SPLITS:
        frame = pred_splits[sp]
        frame["pred"] = (frame["score"] >= threshold).astype(int)
        m = binary_metrics(frame["label"].to_numpy(), frame["score"].to_numpy(), threshold)
        metrics[sp] = m
        logger.info("%s metrics: accuracy=%.4f, f1=%.4f", sp, m["accuracy"], m["f1"])
    return metrics


def run_e01(args: argparse.Namespace) -> dict:
    """Train, evaluate and persist the E01 BERT baseline on DS01.

    Pipeline:
      1. Resolve hyper-parameters (CLI values override the ``DEFAULT_*``
         constants via ``resolve_arg``).
      2. Create a run context / directory.
      3-4. Locate DS01 in the dataset manifest and load its splits.
      5. Train a transformer classifier with a custom MLP head.
      6. Save the model payload.
      7-9. Score every split, pick the F1-maximising threshold on ``dev``,
         and compute per-split metrics.
      10-11. Write predictions, per-split metric CSVs, config, manifest and
         report.

    Because the threshold is tuned on ``dev``, the reported dev metrics are
    optimistic; ``test`` is not used for any selection in this function.

    Args:
        args: Parsed CLI namespace (see :func:`main`). ``merge_dev`` is
            optional and defaults to ``False`` so namespaces built elsewhere
            still work.

    Returns:
        ``{"dev_f1": float-like, "test_f1": float-like, "threshold": float}``.
    """
    setup_logging(args.log_level)
    logger = logging.getLogger("E01")

    # 1. Resolve hyper-parameters (CLI overrides script defaults).
    model_id = resolve_arg(args.model_id, DEFAULT_MODEL_ID)
    epochs = resolve_arg(args.epochs, DEFAULT_EPOCHS)
    batch_size = resolve_arg(args.batch_size, DEFAULT_BATCH_SIZE)
    eval_batch_size = resolve_arg(args.eval_batch_size, DEFAULT_EVAL_BATCH_SIZE)
    max_len = resolve_arg(args.max_len, DEFAULT_MAX_LEN)
    lr = resolve_arg(args.lr, DEFAULT_LR)
    weight_decay = resolve_arg(args.weight_decay, DEFAULT_WEIGHT_DECAY)
    warmup_ratio = resolve_arg(args.warmup_ratio, DEFAULT_WARMUP_RATIO)
    grad_acc = resolve_arg(args.grad_acc, 1)
    early_stopping_patience = resolve_arg(args.early_stopping_patience, DEFAULT_EARLY_STOPPING_PATIENCE)
    # AMP is enabled unless --no_amp is given. NOTE(review): the previous
    # expression `args.use_amp if args.use_amp else True` was always truthy,
    # so `args.use_amp` never had any effect; this keeps that behaviour.
    use_amp = not args.no_amp
    # `--merge_dev` is registered by main(); tolerate namespaces without it.
    merge_dev = bool(getattr(args, "merge_dev", False))
    # Fixed (not CLI-exposed) head settings; recorded in the config below.
    head_dropout = 0.5
    head_intermediate = 512

    # 2. Set up the run context.
    ctx = create_run_context(eid="E01", output_root=Path(args.output_root), run_name=args.run_name)
    logger.info(
        "E01 BERT baseline start | run_name=%s | smoke=%s | merge_dev=%s",
        ctx.run_name, args.smoke, merge_dev,
    )
    logger.info(
        "Hyper-params: model=%s, epochs=%s, max_len=%s, lr=%s, bs=%s",
        model_id, epochs, max_len, lr, batch_size,
    )

    # 3. Load the dataset manifest and locate DS01.
    manifest = load_dataset_manifest(Path(args.manifest_file))
    ds_meta = get_ds_meta(manifest, "DS01")
    logger.info("Dataset: %s | dir=%s", ds_meta["dataset_id"], ds_meta["dataset_dir"])

    # 4. Load splits (optionally merging train+dev and re-splitting).
    splits = prepare_primary_ds_for_train(ds_meta, smoke=args.smoke, seed=args.seed, merge_dev=merge_dev)
    for sp in SPLITS:
        logger.info(" %s: %d rows", sp, len(splits[sp]))

    # 5. Train with champion-aligned settings. The dev scores returned by the
    # trainer are not used; every split is re-scored uniformly in step 7.
    model_output_dir = ctx.run_dir / "model"
    model_payload, _dev_scores = train_transformer_classifier(
        train_df=splits["train"],
        dev_df=splits["dev"],
        model_id=model_id,
        model_output_dir=model_output_dir,
        seed=args.seed,
        device=args.device,
        epochs=epochs,
        batch_size=batch_size,
        eval_batch_size=eval_batch_size,
        learning_rate=lr,
        max_len=max_len,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        gradient_accumulation_steps=grad_acc,
        use_custom_head=True,
        custom_head_dropout=head_dropout,
        custom_head_intermediate=head_intermediate,
        save_best=True,
        early_stopping_patience=early_stopping_patience,
        use_amp=use_amp,
        loss_divergence_patience=2,
        max_restarts=1,
    )
    logger.info("Training complete. Model output directory: %s", model_output_dir)

    # 6. Persist the model payload.
    save_model_payload(ctx.run_dir / "model_payload.pkl", model_payload)

    # 7. Score every split.
    pred_splits = _predict_all_splits(model_payload, splits, args.device)

    # 8. Pick the decision threshold that maximises F1 on dev.
    threshold = best_threshold_by_f1(
        pred_splits["dev"]["label"].to_numpy(),
        pred_splits["dev"]["score"].to_numpy(),
    )
    logger.info("Best dev threshold by F1: %.4f", threshold)

    # 9. Apply the threshold and compute metrics per split.
    metrics = _threshold_and_score_splits(pred_splits, threshold, logger)

    # 10. Save predictions and per-split metrics.
    save_pred_files(ctx, pred_splits)
    for sp in SPLITS:
        pd.DataFrame([metrics[sp]]).to_csv(ctx.run_dir / f"bert_{sp}_metrics.csv", index=False)

    # 11. Write the run config, manifest and report.
    config = {
        "model_id": model_id,
        "epochs": epochs,
        "batch_size": batch_size,
        "eval_batch_size": eval_batch_size,
        "max_len": max_len,
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "warmup_ratio": warmup_ratio,
        "grad_acc": grad_acc,
        "early_stopping_patience": early_stopping_patience,
        "use_amp": use_amp,
        "threshold": float(threshold),
        "seed": args.seed,
        "smoke": args.smoke,
        "device": args.device,
        # Added so the config fully describes the data split and model head.
        "merge_dev": merge_dev,
        "custom_head_dropout": head_dropout,
        "custom_head_intermediate": head_intermediate,
    }
    write_yaml_minimal(ctx.config_file, config)
    result = {
        "dev_f1": metrics["dev"]["f1"],
        "test_f1": metrics["test"]["f1"],
        "threshold": float(threshold),
    }
    write_run_manifest(ctx, status="success", payload=result)
    write_run_report(ctx, status="success", config=config, payload=result)
    logger.info("E01 complete.")
    return result
def main() -> int:
    """Command-line entry point for E01.

    Builds the parser from the shared base and training argument sets and adds
    the E01-specific ``--merge_dev`` flag. Then it runs :func:`run_e01` with
    the parsed arguments. Returns 0 on success. Any exception is logged with
    its traceback on the ``E01`` logger and then re-raised.
    """
    cli = add_train_args(add_base_args(argparse.ArgumentParser(description="E01 BERT baseline")))
    cli.add_argument(
        "--merge_dev",
        action="store_true",
        help="Merge original train+dev and re-split before training.",
    )
    parsed = cli.parse_args()
    try:
        run_e01(parsed)
    except Exception as exc:
        # Logger.exception is error(..., exc_info=True): message plus traceback.
        logging.getLogger("E01").exception(f"ERROR: {exc}")
        raise
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|