Proooof committed on
Commit
5a6b9dd
·
verified ·
1 Parent(s): 162ca90

Create training/train_sentiment.py

Browse files
Files changed (1) hide show
  1. training/train_sentiment.py +61 -0
training/train_sentiment.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import argparse
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import pandas as pd
from datasets import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

from training.utils import compute_metrics_sentiment


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the sentiment training script.

    Args:
        argv: Argument list to parse. ``None`` makes argparse read
            ``sys.argv[1:]``, which is what running the script does.

    Returns:
        The parsed namespace (``model_name``, ``train_csv``, ``eval_csv``,
        ``text_col``, ``label_col``, ``output_dir``, ``epochs``,
        ``batch_size``, ``lr``).
    """
    parser = argparse.ArgumentParser(
        description="Fine-tune a sequence-classification model on CSV data."
    )
    parser.add_argument("--model_name", default="distilbert-base-uncased")
    parser.add_argument("--train_csv", required=True)
    parser.add_argument("--eval_csv", required=True)
    parser.add_argument("--text_col", default="text")
    parser.add_argument("--label_col", default="label")
    parser.add_argument("--output_dir", default="./outputs/sentiment")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=5e-5)
    return parser.parse_args(argv)


def load_split(csv_path: str, text_col: str, label_col: str) -> "pd.DataFrame":
    """Read one CSV split and check that the required columns are present.

    Args:
        csv_path: Path of the CSV file to read.
        text_col: Name of the column holding the input text.
        label_col: Name of the column holding the class label.

    Returns:
        The loaded data frame, unmodified.

    Raises:
        ValueError: If ``text_col`` or ``label_col`` is not a column of the file.
    """
    frame = pd.read_csv(csv_path)
    missing = [col for col in (text_col, label_col) if col not in frame.columns]
    if missing:
        raise ValueError(
            f"{csv_path}: missing required column(s): {', '.join(missing)}"
        )
    return frame


def build_label_maps(
    train_labels: Sequence[Any], eval_labels: Sequence[Any] = ()
) -> Tuple[List[Any], Dict[Any, int], Dict[int, Any]]:
    """Derive the label vocabulary from the training labels.

    Ids are assigned in sorted order of the distinct training labels, so the
    mapping is deterministic for a given training set.

    Args:
        train_labels: Label values found in the training split.
        eval_labels: Label values found in the evaluation split; only used to
            verify that each of them is also a training label.

    Returns:
        ``(label_names, label2id, id2label)``.

    Raises:
        ValueError: If there are no training labels, or if the evaluation
            split contains a label that never occurs in the training split.
    """
    label_names = sorted(set(train_labels))
    if not label_names:
        raise ValueError("training data contains no labels")
    label2id = {label: idx for idx, label in enumerate(label_names)}
    id2label = {idx: label for label, idx in label2id.items()}

    # Fail early with a readable message instead of a KeyError inside map().
    unseen = set(eval_labels) - set(label2id)
    if unseen:
        raise ValueError(
            f"evaluation labels not present in training data: {sorted(unseen)}"
        )
    return label_names, label2id, id2label


def make_encoder(
    tokenizer: Callable[..., Any],
    text_col: str,
    label_col: str,
    label2id: Dict[Any, int],
) -> Callable[[Dict[str, List[Any]]], Any]:
    """Build the batched preprocessing function passed to ``Dataset.map``.

    Args:
        tokenizer: Callable that tokenizes a list of strings.
        text_col: Name of the text column in each batch.
        label_col: Name of the label column in each batch.
        label2id: Mapping from raw label value to integer class id.

    Returns:
        A function that takes one batch (a mapping of column name to list of
        values) and returns the tokenizer output with a ``"labels"`` entry.
    """

    def encode(batch):
        # With batched=True, map() hands over plain Python lists per column,
        # so the values are used directly (no pandas-style .tolist()).
        # No padding here: sequences are padded per training batch by the
        # Trainer's default collator, which is used when a tokenizer is given.
        encoded = tokenizer(batch[text_col], truncation=True)
        encoded["labels"] = [label2id[label] for label in batch[label_col]]
        return encoded

    return encode


def tokenize_split(frame: "pd.DataFrame", encode: Callable[..., Any]) -> "Dataset":
    """Convert a data frame to a tokenized dataset, dropping the raw columns."""
    dataset = Dataset.from_pandas(frame)
    # Use the dataset's own column list so any column added during the pandas
    # conversion is removed as well.
    return dataset.map(encode, batched=True, remove_columns=dataset.column_names)


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Fine-tune the classifier and save the model and tokenizer.

    Args:
        argv: Optional argument list; ``None`` reads ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    train_df = load_split(args.train_csv, args.text_col, args.label_col)
    eval_df = load_split(args.eval_csv, args.text_col, args.label_col)

    label_names, label2id, id2label = build_label_maps(
        train_df[args.label_col].unique().tolist(),
        eval_df[args.label_col].unique().tolist(),
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    encode = make_encoder(tokenizer, args.text_col, args.label_col, label2id)
    train_ds = tokenize_split(train_df, encode)
    eval_ds = tokenize_split(eval_df, encode)

    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=len(label_names),
        id2label=id2label,
        label2id=label2id,
    )

    # NOTE(review): recent transformers releases rename `evaluation_strategy`
    # to `eval_strategy`; confirm against the pinned version.
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        evaluation_strategy="epoch",
        # load_best_model_at_end requires the save strategy to match the
        # evaluation strategy; the default save strategy is step-based.
        save_strategy="epoch",
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        load_best_model_at_end=True,
        # Presumably compute_metrics_sentiment reports an "accuracy" key;
        # verify in training/utils.py.
        metric_for_best_model="accuracy",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_sentiment,
    )

    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)


if __name__ == "__main__":
    main()