# viswanani's picture
# Create src/train.py
# 69914b2 verified
import os
import mlflow
import yaml
from transformers import (
Trainer,
TrainingArguments,
AutoModelForSequenceClassification,
AutoTokenizer
)
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import torch
def compute_metrics(eval_pred):
    """Compute classification accuracy for one evaluation pass.

    Args:
        eval_pred: A pair that unpacks into ``(logits, labels)``. The logits
            are converted to a tensor and reduced along dimension 1, so they
            are assumed to be laid out as (examples, classes) -- confirm
            against the caller.

    Returns:
        A dict with the single key ``"accuracy"`` holding the value returned
        by ``accuracy_score(labels, predictions)``.
    """
    raw_scores, gold = eval_pred
    # The highest-scoring entry along dimension 1 is the predicted class.
    predicted = torch.tensor(raw_scores).argmax(dim=1)
    return {"accuracy": accuracy_score(gold, predicted)}
def load_config(path="configs/training_config.yaml"):
    """Load the training configuration from a YAML file.

    Args:
        path: Location of the YAML file. Defaults to
            ``configs/training_config.yaml``, resolved relative to the
            current working directory.

    Returns:
        The parsed configuration as a dict.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the top level of the document is not a mapping (for
            example an empty file, which parses to ``None``).
    """
    # Be explicit about the encoding so parsing does not depend on the
    # platform's default locale.
    with open(path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # Callers index the result by key; fail here with a clear message instead
    # of an opaque "NoneType is not subscriptable" later on.
    if not isinstance(cfg, dict):
        raise ValueError(
            f"{path} must contain a YAML mapping at the top level, "
            f"got {type(cfg).__name__}"
        )
    return cfg
def _prepare_split(split, tokenizer, num_examples, seed=42):
    """Shuffle, subsample, tokenize and torch-format one dataset split."""
    # Subsample BEFORE tokenizing so only the rows that are actually used get
    # padded to max length; the selected rows are the same either way because
    # the shuffle permutation depends only on the seed and the split length.
    subset = split.shuffle(seed=seed).select(range(num_examples))
    subset = subset.map(
        lambda batch: tokenizer(batch["text"], truncation=True, padding="max_length"),
        batched=True,
    )
    return subset.rename_column("label", "labels").with_format("torch")


def main():
    """Fine-tune a sequence classifier on an IMDB subset and track it in MLflow.

    Reads the configuration via ``load_config()``. The ``"model_name"`` entry
    selects the pretrained checkpoint; every other entry is forwarded to
    ``TrainingArguments``. Inside a single MLflow run the function trains on
    2000 shuffled training examples, evaluates on 1000 shuffled test examples,
    then logs the configuration as run parameters and the model as an
    artifact named ``"model"``.

    Raises:
        KeyError: If the configuration has no ``"model_name"`` entry.
    """
    cfg = load_config()

    # "model_name" is consumed here to pick the checkpoint. It is not a
    # TrainingArguments field, so passing the whole config through would raise
    # TypeError (unexpected keyword argument). Strip it before forwarding.
    training_kwargs = {key: value for key, value in cfg.items() if key != "model_name"}

    mlflow.set_experiment("huggingface-fulltrack-clone")
    with mlflow.start_run():
        tokenizer = AutoTokenizer.from_pretrained(cfg["model_name"])
        model = AutoModelForSequenceClassification.from_pretrained(cfg["model_name"], num_labels=2)

        raw = load_dataset("imdb")
        train_dataset = _prepare_split(raw["train"], tokenizer, 2000)
        # The IMDB splits are stored ordered by label, so taking the first
        # 1000 test rows unshuffled would evaluate on a single class and make
        # the reported accuracy meaningless. Shuffle before selecting.
        eval_dataset = _prepare_split(raw["test"], tokenizer, 1000)

        training_args = TrainingArguments(**training_kwargs)
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            # NOTE(review): newer transformers releases rename this argument
            # to `processing_class`; confirm against the pinned version.
            tokenizer=tokenizer,
            compute_metrics=compute_metrics,
        )
        trainer.train()
        trainer.evaluate()

        # Log the full config (including model_name) so the run records which
        # checkpoint was fine-tuned.
        mlflow.log_params(cfg)
        mlflow.pytorch.log_model(model, artifact_path="model")
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()