Percy3822 committed
Commit 86cb75d · verified · 1 Parent(s): 330f9ce

Update train.py

Files changed (1)
  1. train.py +45 -48
train.py CHANGED
@@ -1,50 +1,47 @@
 import argparse
 from datasets import load_dataset
-from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--dataset", required=True)
-parser.add_argument("--output", default="trained_model")
-args = parser.parse_args()
-
-print("📊 Loading dataset...")
-dataset = load_dataset("json", data_files=args.dataset, split="train")
-
-print("🧠 Loading model and tokenizer...")
-tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
-tokenizer.pad_token = tokenizer.eos_token
-model = AutoModelForCausalLM.from_pretrained("distilgpt2")
-
-# ✅ Clean, batch-safe tokenize
-def tokenize(batch):
-    full_texts = [str(p) + str(c) for p, c in zip(batch["prompt"], batch["completion"])]
-    return tokenizer(full_texts, padding="max_length", truncation=True, max_length=256)
-
-print("🔁 Tokenizing...")
-tokenized = dataset.map(tokenize, batched=True)
-
-print("📦 Setting up trainer...")
-data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
-
-training_args = TrainingArguments(
-    output_dir=args.output,
-    per_device_train_batch_size=2,
-    num_train_epochs=1,
-    logging_steps=1,
-    save_steps=5,
-    save_total_limit=1,
-    report_to=[]
-)
-
-trainer = Trainer(
-    model=model,
-    args=training_args,
-    train_dataset=tokenized,
-    tokenizer=tokenizer,
-    data_collator=data_collator,
-)
-
-print("🚀 Starting training...")
-trainer.train()
-trainer.save_model(args.output)
-print("✅ Done.")
+from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
+import torch
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset", type=str, required=True)
+    args = parser.parse_args()
+
+    print("📥 Loading dataset...")
+    dataset = load_dataset("json", data_files=args.dataset, split="train")
+
+    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
+    tokenizer.pad_token = tokenizer.eos_token
+
+    def tokenize_function(examples):
+        return tokenizer(examples["prompt"], truncation=True, padding="max_length", max_length=256)
+
+    tokenized_dataset = dataset.map(tokenize_function, batched=True)
+
+    print("📦 Loading model...")
+    model = AutoModelForCausalLM.from_pretrained("distilgpt2")
+
+    training_args = TrainingArguments(
+        output_dir="./trained_model",
+        overwrite_output_dir=True,
+        num_train_epochs=1,
+        per_device_train_batch_size=2,
+        save_strategy="epoch",
+        logging_dir="./logs",
+        logging_steps=10,
+        no_cuda=not torch.cuda.is_available()
+    )
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=tokenized_dataset
+    )
+
+    print("🚀 Starting training...")
+    trainer.train()
+    print("✅ Training finished. Model saved to ./trained_model")
+
+if __name__ == "__main__":
+    main()
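
Note on the new version: tokenize_function now encodes only the "prompt" field, and the DataCollatorForLanguageModeling from the removed code is gone, so the batches reaching Trainer carry input_ids and attention_mask but no labels. AutoModelForCausalLM returns a loss only when labels are supplied, and Trainer raises an error otherwise. A minimal sketch of how labels could be restored, reusing the collator from the removed lines (not part of this commit; variable names follow the new train.py):

    from transformers import DataCollatorForLanguageModeling

    # With mlm=False the collator clones input_ids into labels
    # (pad-token positions masked to -100), which is what the
    # causal-LM Trainer needs to compute a loss.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )

Since the --output flag was also removed, the script is now invoked as "python train.py --dataset data.jsonl" and always writes checkpoints under ./trained_model.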