File size: 1,627 Bytes
e323466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""Fine-tune DistilBERT for intent classification and evaluate on the test set."""

import sys
from pathlib import Path

import yaml
from loguru import logger

# Put the directory two levels above this file at the front of sys.path so the
# `src.*` imports below resolve when this script is run directly.
# NOTE(review): presumably that directory is the repository root — confirm.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.data.dataset import load_splits
from src.data.preprocessing import set_global_seeds
from src.models.intent_classifier import train, evaluate


def main() -> None:
    """Run the DistilBERT fine-tuning and evaluation pipeline.

    Steps, in order:

    1. Create ``logs/`` if needed and attach a loguru file sink that rotates
       at 10 MB.
    2. Load ``config/config.yaml`` and seed via ``set_global_seeds`` using
       ``classifier.seed``.
    3. Load the train/validation/test splits from ``paths.data_processed``.
    4. Call ``train`` with the train and validation splits, saving under
       ``paths.models_distilbert``.
    5. Call ``evaluate`` on the ``best`` subdirectory of that save location
       with the test split, log the weighted-average F1, and warn when it is
       below the target.

    All file-system paths used here (``logs/``, ``config/config.yaml``) are
    relative, so they resolve against the current working directory.

    Raises:
        FileNotFoundError: If ``config/config.yaml`` does not exist.
        KeyError: If a configuration key read here is missing.
    """
    # Minimum acceptable weighted F1 on the test set; used for the warning.
    target_f1 = 0.89

    Path("logs").mkdir(exist_ok=True)
    logger.add("logs/train_classifier.log", rotation="10 MB")

    # Explicit encoding: without it, open() falls back to the locale's
    # preferred encoding, which is not UTF-8 on every platform.
    with open("config/config.yaml", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    classifier_cfg = cfg["classifier"]
    paths_cfg = cfg["paths"]

    # Seed before any data loading or model construction.
    set_global_seeds(classifier_cfg["seed"])

    train_df, val_df, test_df = load_splits(paths_cfg["data_processed"])

    # The return value of train() is not needed here; evaluation reloads the
    # model from disk.
    train(
        train_df=train_df,
        val_df=val_df,
        cfg=cfg,
        save_dir=paths_cfg["models_distilbert"],
    )

    # NOTE(review): presumably train() writes its best checkpoint to
    # "<save_dir>/best" — confirm against src.models.intent_classifier.
    model_dir = str(Path(paths_cfg["models_distilbert"]) / "best")
    report = evaluate(
        model_dir=model_dir,
        test_df=test_df,
        results_dir=paths_cfg["results"],
        # Evaluation uses twice the training batch size.
        batch_size=classifier_cfg["batch_size"] * 2,
        max_length=classifier_cfg["max_length"],
    )

    weighted_f1 = report["weighted avg"]["f1-score"]
    logger.info(f"DistilBERT complete. Test weighted F1: {weighted_f1:.4f}")

    if weighted_f1 < target_f1:
        logger.warning(
            f"Weighted F1 {weighted_f1:.4f} is below target of {target_f1}. "
            "Consider tuning hyperparameters or training for more epochs."
        )


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()