lap096 committed on
Commit
c40dcd1
·
verified ·
1 Parent(s): 6e58093

Create agent_1_train.py

Browse files
Files changed (1) hide show
  1. agent_1_train.py +62 -0
agent_1_train.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# agent_1_train.py
"""Train a tiny GPT-2 style language model from scratch on WikiText-2.

The script downloads the raw WikiText-2 corpus and the pretrained GPT-2
tokenizer, builds a small randomly initialised GPT-2 model, trains it with
the Hugging Face ``Trainer`` and writes the model and tokenizer to
``OUTPUT_DIR``.
"""
from datasets import load_dataset
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

# Context length shared by the model's position embeddings and the tokenizer
# truncation/padding length; the two must agree, so it is defined once.
MAX_LENGTH = 128

# Directory used both for Trainer checkpoints and for the final export.
OUTPUT_DIR = "./tiny-gpt"

# Load dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# GPT-2 ships without a padding token; reuse end-of-text so padding works.
tokenizer.pad_token = tokenizer.eos_token

# Tiny GPT-2 configuration (~16M parameters with the GPT-2 vocabulary; the
# token embedding table accounts for most of them).
config = GPT2Config(
    vocab_size=tokenizer.vocab_size,
    n_positions=MAX_LENGTH,
    # NOTE(review): n_ctx is a legacy GPT2Config field kept from the original
    # script; recent transformers releases appear to ignore it - confirm
    # before removing.
    n_ctx=MAX_LENGTH,
    n_embd=256,
    n_layer=4,
    n_head=4,
)
model = GPT2LMHeadModel(config)


def has_text(example):
    """Return True when the row's ``text`` field is not blank."""
    return bool(example["text"].strip())


# Tokenize dataset
def tokenize_function(examples):
    """Tokenize a batch of rows, truncating/padding each to MAX_LENGTH tokens."""
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=MAX_LENGTH,
        padding="max_length",
    )


# WikiText contains many blank rows. Padded to MAX_LENGTH they consist only of
# pad tokens, so they carry no training signal, and a batch made up solely of
# such rows can yield a NaN loss. Drop them before tokenizing.
tokenized_datasets = dataset.filter(has_text).map(tokenize_function, batched=True)
tokenized_datasets.set_format(type="torch", columns=["input_ids"])

# Data collator; mlm=False selects the causal (next-token) objective rather
# than masked-language-modelling.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Training arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    save_steps=500,
    save_total_limit=2,
    logging_steps=50,
    learning_rate=5e-4,
    fp16=False,
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Train model
trainer.train()

# Save model and tokenizer side by side so the directory can be reloaded
# with from_pretrained().
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

print(f"Training complete! Model saved in {OUTPUT_DIR}")