louisedrumm committed on
Commit
5b00ca5
·
1 Parent(s): 5b3cb77

Update trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +40 -0
trainer.py CHANGED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Fine-tune GPT-2 on a plain-text file using the Hugging Face Trainer API."""

from transformers import (
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    TextDataset,
    Trainer,
    TrainingArguments,
)

# Load the pretrained model and its matching tokenizer from the Hub.
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Build a dataset of fixed-length token blocks from the training text file.
# NOTE(review): TextDataset is deprecated in recent transformers releases;
# the `datasets` library (`load_dataset("text", ...)`) is the recommended
# replacement — confirm against the pinned transformers version.
dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="my_text_file.txt",
    block_size=128,  # tokens per training example
)

# Collator for causal language modeling: mlm=False means plain
# next-token prediction (no masked-token objective).
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

# Training configuration.
training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=1,
    save_steps=10_000,    # checkpoint every 10k optimizer steps
    save_total_limit=2,   # keep only the two most recent checkpoints
)

# Assemble the trainer and run the fine-tuning loop.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

trainer.train()