| import os |
| import yaml |
| import numpy as np |
| import mlflow |
| import mlflow.pytorch |
| from pathlib import Path |
| from dotenv import load_dotenv |
| from transformers import ( |
| DistilBertForSequenceClassification, |
| TrainingArguments, |
| Trainer, |
| EarlyStoppingCallback, |
| ) |
| from sklearn.metrics import accuracy_score, f1_score |
|
|
| from src.training.dataset import load_ag_news, ID2LABEL, LABEL2ID, NUM_LABELS |
|
|
# Load variables from a .env file into os.environ at import time, so that
# MLFLOW_TRACKING_URI (read by train()) can be provided without exporting it
# in the shell.
load_dotenv()
|
|
|
|
| |
| |
| |
def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 for a Trainer evaluation pass.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by ``Trainer`` (an
            ``EvalPrediction`` unpacks the same way). ``logits`` may be a tuple
            of arrays when the model returns extra outputs; its first element
            is then used as the class logits.

    Returns:
        Dict with ``"accuracy"`` and ``"f1"`` (support-weighted F1). Trainer
        reports them as ``eval_accuracy`` / ``eval_f1``.
    """
    logits, labels = eval_pred
    # Models configured to also return hidden states / attentions give Trainer
    # a tuple of outputs; the classification logits are the first entry.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, predictions),
        # zero_division=0 yields the same value sklearn's default produces for
        # classes that are never predicted (frequent on small subsets), but
        # without emitting UndefinedMetricWarning on every evaluation.
        "f1": f1_score(labels, predictions, average="weighted", zero_division=0),
    }
|
|
|
|
| |
def train(config: dict, subset: int = None):
    """Fine-tune a DistilBERT classifier on AG News and log the run to MLflow.

    Args:
        config: Parsed training config. Keys read here:
            ``mlflow.experiment_name``; ``model.base``, ``model.max_length``;
            ``training.epochs``, ``training.batch_size``,
            ``training.learning_rate``, ``training.warmup_steps``,
            ``training.weight_decay``, ``training.early_stopping_patience``;
            optionally ``training.eval_batch_size`` (default 64) and
            ``training.output_dir`` (default ``"artifacts/checkpoints"``).
        subset: If truthy, train on at most ``subset`` examples and evaluate on
            at most ``subset // 10`` (but at least one) test examples, for quick
            smoke runs. ``None`` or ``0`` uses the full splits.

    Raises:
        ValueError: If ``subset`` is negative.
        KeyError: If ``MLFLOW_TRACKING_URI`` is not set or a required config
            key is missing.
    """
    if subset is not None and subset < 0:
        raise ValueError(f"subset must be a non-negative int or None, got {subset!r}")

    mlflow.set_tracking_uri(os.environ["MLFLOW_TRACKING_URI"])
    mlflow.set_experiment(config["mlflow"]["experiment_name"])

    print("Loading dataset...")
    # The tokenizer is also returned, but tokenization already happened inside
    # load_ag_news, so it is not needed here.
    dataset, _ = load_ag_news(
        tokenizer_name=config["model"]["base"],
        max_length=config["model"]["max_length"],
    )

    if subset:
        # Clamp to the split sizes: Dataset.select raises IndexError for
        # out-of-range indices, and subset < 10 would otherwise produce an empty
        # eval split, which breaks per-epoch evaluation / best-model selection.
        n_train = min(subset, len(dataset["train"]))
        n_test = min(max(subset // 10, 1), len(dataset["test"]))
        dataset["train"] = dataset["train"].select(range(n_train))
        dataset["test"] = dataset["test"].select(range(n_test))
        print(f"Subset mode: using {n_train} train, {n_test} test examples")

    print("Loading model...")
    model = DistilBertForSequenceClassification.from_pretrained(
        config["model"]["base"],
        num_labels=NUM_LABELS,
        id2label=ID2LABEL,
        label2id=LABEL2ID,
    )

    tc = config["training"]

    args = TrainingArguments(
        output_dir=tc.get("output_dir", "artifacts/checkpoints"),
        num_train_epochs=tc["epochs"],
        per_device_train_batch_size=tc["batch_size"],
        per_device_eval_batch_size=tc.get("eval_batch_size", 64),
        # PyYAML (YAML 1.1) parses exponent-only floats such as "2e-5" or
        # "1e-2" as strings, hence the explicit float() casts.
        learning_rate=float(tc["learning_rate"]),
        warmup_steps=tc["warmup_steps"],
        weight_decay=float(tc["weight_decay"]),
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        logging_steps=100,
        # MLflow logging is done explicitly below rather than via Trainer.
        report_to="none",
    )

    # NOTE(review): dataset["test"] is used both for early stopping /
    # best-checkpoint selection and for the final metrics below, so the
    # reported numbers are optimistically biased; consider a separate
    # validation split.
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(
            early_stopping_patience=tc["early_stopping_patience"]
        )],
    )

    with mlflow.start_run():
        mlflow.log_params({
            "base_model": config["model"]["base"],
            "max_length": config["model"]["max_length"],
            "epochs": tc["epochs"],
            "batch_size": tc["batch_size"],
            "learning_rate": tc["learning_rate"],
            "warmup_steps": tc["warmup_steps"],
            "weight_decay": tc["weight_decay"],
            "subset": subset or "full",
        })

        print("Training...")
        trainer.train()

        # With load_best_model_at_end=True this evaluates the best checkpoint.
        print("Evaluating...")
        metrics = trainer.evaluate()
        print(metrics)

        mlflow.log_metrics({
            "accuracy": metrics["eval_accuracy"],
            "f1": metrics["eval_f1"],
            "loss": metrics["eval_loss"],
        })

        print(f"\nFinal accuracy : {metrics['eval_accuracy']:.4f}")
        print(f"Final F1 : {metrics['eval_f1']:.4f}")
|
|
|
|
if __name__ == "__main__":
    # Local import: argparse is only needed when run as a script.
    import argparse

    parser = argparse.ArgumentParser(description="Fine-tune DistilBERT on AG News.")
    parser.add_argument(
        "--config",
        type=Path,
        default=Path(__file__).parent.parent.parent / "configs" / "training_config.yaml",
        help="Path to the training YAML config (default: %(default)s).",
    )
    parser.add_argument(
        "--subset",
        type=int,
        default=None,
        help="Train on only this many examples (eval on subset // 10) for a quick run.",
    )
    cli_args = parser.parse_args()

    # safe_load: the config is plain data and must not construct Python objects.
    with open(cli_args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    train(config, subset=cli_args.subset)
|
|