"""Fine-tune DistilBERT for intent classification and evaluate on the test set."""

import sys
from pathlib import Path

import yaml
from loguru import logger

# Make the directory two levels above this file importable so the ``src``
# package resolves when this file is executed directly as a script.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.data.dataset import load_splits  # noqa: E402
from src.data.preprocessing import set_global_seeds  # noqa: E402
from src.models.intent_classifier import train, evaluate  # noqa: E402

# Target weighted F1 on the test set. Falling below it only logs a warning;
# it does not make the run fail.
TARGET_WEIGHTED_F1 = 0.89


def main() -> None:
    """Run DistilBERT fine-tuning and evaluation pipeline.

    Steps:
      1. Add a file log sink at ``logs/train_classifier.log`` (rotated at 10 MB).
      2. Load ``config/config.yaml``. The path is relative to the current
         working directory.
      3. Seed RNGs with ``cfg["classifier"]["seed"]``.
      4. Load the train/val/test splits from ``cfg["paths"]["data_processed"]``.
      5. Fine-tune via ``train`` with ``cfg["paths"]["models_distilbert"]`` as
         the save directory.
      6. Evaluate the ``best`` sub-directory of that save directory on the test
         split, passing ``cfg["paths"]["results"]`` as the results directory.
      7. Log the test weighted F1. Warn if it is below ``TARGET_WEIGHTED_F1``.
    """
    Path("logs").mkdir(exist_ok=True)
    logger.add("logs/train_classifier.log", rotation="10 MB")

    with open("config/config.yaml", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    set_global_seeds(cfg["classifier"]["seed"])

    processed_dir = cfg["paths"]["data_processed"]
    train_df, val_df, test_df = load_splits(processed_dir)

    # The return value of ``train`` is not needed here. Evaluation reloads the
    # model from disk instead.
    train(
        train_df=train_df,
        val_df=val_df,
        cfg=cfg,
        save_dir=cfg["paths"]["models_distilbert"],
    )

    # NOTE(review): assumes ``train`` writes its best checkpoint to
    # ``<save_dir>/best``. Confirm against src.models.intent_classifier.
    model_dir = str(Path(cfg["paths"]["models_distilbert"]) / "best")
    report = evaluate(
        model_dir=model_dir,
        test_df=test_df,
        results_dir=cfg["paths"]["results"],
        # Eval uses twice the training batch size. Presumably this works
        # because inference needs no gradient memory; verify on target hardware.
        batch_size=cfg["classifier"]["batch_size"] * 2,
        max_length=cfg["classifier"]["max_length"],
    )

    # NOTE(review): ``report`` appears to follow sklearn's
    # ``classification_report(..., output_dict=True)`` layout. Confirm.
    weighted_f1 = report["weighted avg"]["f1-score"]
    logger.info(f"DistilBERT complete. Test weighted F1: {weighted_f1:.4f}")

    if weighted_f1 < TARGET_WEIGHTED_F1:
        logger.warning(
            f"Weighted F1 {weighted_f1:.4f} is below target of {TARGET_WEIGHTED_F1}. "
            "Consider tuning hyperparameters or training for more epochs."
        )


if __name__ == "__main__":
    main()