pangxiang committed on
Commit
bee18d3
·
verified ·
1 Parent(s): a189a3a

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +29 -0
train.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
2
+ from datasets import Dataset
3
+ import json
# Locations and model id used by the training run.
TRAIN_DATA_PATH = "train_data.json"
MODEL_NAME = "distilbert-base-uncased"
OUTPUT_DIR = "./results"
SAVE_DIR = "./trained_model"


def build_examples(data):
    """Flatten ``{label_name: [text, ...]}`` into parallel training lists.

    JSON object keys are always strings, but the classification head needs
    integer class ids, so each label name is assigned an id in the order it
    appears in ``data``.

    Args:
        data: Mapping from label name to an iterable of sample texts.

    Returns:
        A ``(texts, label_ids, label2id)`` tuple where ``texts[i]`` belongs to
        class ``label_ids[i]`` and ``label2id`` maps each label name to its id.
    """
    label2id = {name: idx for idx, name in enumerate(data)}
    texts = []
    label_ids = []
    for name, samples in data.items():
        start = len(texts)
        texts.extend(samples)
        # One id per sample actually added, so both lists stay aligned even
        # when ``samples`` is a one-shot iterable.
        label_ids.extend([label2id[name]] * (len(texts) - start))
    return texts, label_ids, label2id


def main():
    """Fine-tune the base model on ``train_data.json`` and save the result."""
    # Explicit encoding: the locale default may not be UTF-8 on every platform.
    with open(TRAIN_DATA_PATH, "r", encoding="utf-8") as f:
        data = json.load(f)

    texts, label_ids, label2id = build_examples(data)
    if not texts:
        raise ValueError(f"{TRAIN_DATA_PATH} contains no training samples")
    id2label = {idx: name for name, idx in label2id.items()}

    dataset = Dataset.from_dict({"text": texts, "label": label_ids})

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Size the classification head from the data and record the label names
    # in the model config so the saved model can report readable labels.
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(label2id),
        id2label=id2label,
        label2id=label2id,
    )

    def tokenize_function(examples):
        """Tokenize a batch of rows taken from the ``text`` column."""
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    tokenized_datasets = dataset.map(tokenize_function, batched=True)

    training_args = TrainingArguments(output_dir=OUTPUT_DIR, num_train_epochs=2)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets,
    )
    trainer.train()

    model.save_pretrained(SAVE_DIR)
    # Save the tokenizer next to the weights so SAVE_DIR is loadable on its own.
    tokenizer.save_pretrained(SAVE_DIR)


if __name__ == "__main__":
    main()