Percy3822 committed
Commit a359ec7 · verified · 1 Parent(s): 579cbca

Update train.py

Files changed (1)
  train.py  +83  -25
train.py CHANGED
@@ -1,47 +1,105 @@
 import argparse
+import os
+import sys
+from typing import List
+
 from datasets import load_dataset
-from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
-import torch
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    DataCollatorForLanguageModeling,
+    TrainingArguments,
+    Trainer,
+)
+
+def parse_args():
+    p = argparse.ArgumentParser()
+    p.add_argument("--dataset", required=True, help="Path to a JSON/JSONL (or .gz) file.")
+    p.add_argument("--output", default="trained_model", help="Folder to save the fine-tuned model.")
+    p.add_argument("--model_name", default="distilgpt2", help="Base model.")
+    p.add_argument("--epochs", type=float, default=1.0)
+    p.add_argument("--batch_size", type=int, default=2)
+    p.add_argument("--block_size", type=int, default=256)
+    p.add_argument("--learning_rate", type=float, default=5e-5)
+    return p.parse_args()
 
 def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset", type=str, required=True)
-    args = parser.parse_args()
+    args = parse_args()
+    print(f"📥 Loading dataset: {args.dataset}", flush=True)
+    ds = load_dataset("json", data_files=args.dataset, split="train")
+
+    cols = ds.column_names
+    print(f"🧾 Columns: {cols}", flush=True)
 
-    print("📥 Loading dataset...")
-    dataset = load_dataset("json", data_files=args.dataset, split="train")
+    print(f"🧠 Loading model & tokenizer: {args.model_name}", flush=True)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token  # GPT-2 family has no pad token
 
-    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
-    tokenizer.pad_token = tokenizer.eos_token
+    model = AutoModelForCausalLM.from_pretrained(args.model_name)
 
-    def tokenize_function(examples):
-        return tokenizer(examples["prompt"], truncation=True, padding="max_length", max_length=256)
+    def build_texts(batch) -> List[str]:
+        if "text" in batch:
+            return [str(t) for t in batch["text"]]
+        if "prompt" in batch and "completion" in batch:
+            # simple join: prompt + newline + completion
+            return [f"{str(p).rstrip()}\n{str(c)}" for p, c in zip(batch["prompt"], batch["completion"])]
+        raise ValueError("Dataset must contain 'text' OR both 'prompt' and 'completion' fields.")
 
-    tokenized_dataset = dataset.map(tokenize_function, batched=True)
+    def tokenize(batch):
+        texts = build_texts(batch)
+        return tokenizer(
+            texts,
+            padding="max_length",
+            truncation=True,
+            max_length=args.block_size,
+        )
 
-    print("📦 Loading model...")
-    model = AutoModelForCausalLM.from_pretrained("distilgpt2")
+    print("🔁 Tokenizing…", flush=True)
+    tokenized = ds.map(
+        tokenize,
+        batched=True,
+        remove_columns=cols,  # keep only tokenized fields
+    )
+
+    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
+    print("⚙ Preparing Trainer…", flush=True)
     training_args = TrainingArguments(
-        output_dir="./trained_model",
+        output_dir=args.output,
         overwrite_output_dir=True,
-        num_train_epochs=1,
-        per_device_train_batch_size=2,
-        save_strategy="epoch",
-        logging_dir="./logs",
+        per_device_train_batch_size=args.batch_size,
+        num_train_epochs=args.epochs,
+        learning_rate=args.learning_rate,
         logging_steps=10,
-        no_cuda=not torch.cuda.is_available()
+        save_steps=200,  # frequent-ish checkpoints (kept to 1)
+        save_total_limit=1,
+        report_to=[],
+        gradient_accumulation_steps=1,
+        fp16=False,  # CPU-friendly; enable if GPU has fp16
     )
 
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=tokenized_dataset
+        train_dataset=tokenized,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
     )
 
-    print("🚀 Starting training...")
+    print("🚀 Training…", flush=True)
     trainer.train()
-    print("✅ Training finished. Model saved to ./trained_model")
 
-if __name__ == "__main__":
-    main()
+    print(f"💾 Saving to: {args.output}", flush=True)
+    os.makedirs(args.output, exist_ok=True)
+    trainer.save_model(args.output)
+    tokenizer.save_pretrained(args.output)
+    print("✅ Done.", flush=True)
+
+if __name__ == "__main__":
+    try:
+        main()
+    except Exception as e:
+        # Make sure a failure returns non-zero so your app can detect it
+        print(f"❌ Training failed: {e}", flush=True)
+        sys.exit(1)
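
The updated script accepts a dataset with either a "text" column or "prompt"/"completion" pairs (see build_texts above). As a rough sketch of what an acceptable input could look like, the following writes a tiny JSONL file and notes the assumed invocation; the file name "sample.jsonl", the records, and the command line are illustrative examples, not part of the commit.

# Sketch: create a minimal prompt/completion dataset that build_texts() would accept.
# "sample.jsonl" and the records below are made-up examples.
import json

records = [
    {"prompt": "Translate to French: Hello", "completion": "Bonjour"},
    {"prompt": "Translate to French: Good night", "completion": "Bonne nuit"},
]

with open("sample.jsonl", "w", encoding="utf-8") as f:
    for r in records:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")

# Assumed invocation of the updated script (paths and values are illustrative):
#   python train.py --dataset sample.jsonl --output trained_model --epochs 1 --batch_size 2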
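
Because the commit now saves both the model and the tokenizer to the --output folder, a quick smoke test of a finished run could look like the sketch below. The directory "trained_model" is just the script's default --output value, and the prompt is a made-up example.

# Sketch: reload the fine-tuned model saved by trainer.save_model() / tokenizer.save_pretrained()
# and generate a short continuation. Assumes training already wrote to "trained_model".
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("trained_model")
model = AutoModelForCausalLM.from_pretrained("trained_model")

inputs = tokenizer("Translate to French: Hello", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=40,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,  # avoid the GPT-2 "no pad token" warning
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))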