Spaces:
Running
Running
File size: 1,627 Bytes
e323466 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 | """Fine-tune DistilBERT for intent classification and evaluate on the test set."""
import sys
from pathlib import Path
import yaml
from loguru import logger
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.data.dataset import load_splits
from src.data.preprocessing import set_global_seeds
from src.models.intent_classifier import train, evaluate
def main() -> None:
"""Run DistilBERT fine-tuning and evaluation pipeline."""
Path("logs").mkdir(exist_ok=True)
logger.add("logs/train_classifier.log", rotation="10 MB")
with open("config/config.yaml") as f:
cfg = yaml.safe_load(f)
set_global_seeds(cfg["classifier"]["seed"])
processed_dir = cfg["paths"]["data_processed"]
train_df, val_df, test_df = load_splits(processed_dir)
trainer = train(
train_df=train_df,
val_df=val_df,
cfg=cfg,
save_dir=cfg["paths"]["models_distilbert"],
)
model_dir = str(Path(cfg["paths"]["models_distilbert"]) / "best")
report = evaluate(
model_dir=model_dir,
test_df=test_df,
results_dir=cfg["paths"]["results"],
batch_size=cfg["classifier"]["batch_size"] * 2,
max_length=cfg["classifier"]["max_length"],
)
weighted_f1 = report["weighted avg"]["f1-score"]
logger.info(f"DistilBERT complete. Test weighted F1: {weighted_f1:.4f}")
if weighted_f1 < 0.89:
logger.warning(
f"Weighted F1 {weighted_f1:.4f} is below target of 0.89. "
"Consider tuning hyperparameters or training for more epochs."
)
if __name__ == "__main__":
main()
|