Percy3822 committed
Commit ca82b75 · verified · 1 Parent(s): 210cdd4

Create train.py

Files changed (1):
  1. train.py +75 -0
train.py ADDED
@@ -0,0 +1,75 @@
import argparse, os, sys
from typing import List
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    DataCollatorForLanguageModeling, TrainingArguments, Trainer
)

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", required=True, help="JSONL (.jsonl or .jsonl.gz)")
    p.add_argument("--output", default="trained_model")
    p.add_argument("--model_name", default="distilgpt2")  # tiny & quick
    p.add_argument("--epochs", type=float, default=0.5)   # short run
    p.add_argument("--batch_size", type=int, default=2)
    p.add_argument("--block_size", type=int, default=256)
    p.add_argument("--learning_rate", type=float, default=5e-5)
    return p.parse_args()
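
# Example invocation (a sketch; the file name "data.jsonl" is illustrative):
#   python train.py --dataset data.jsonl --output trained_model --epochs 0.5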

def main():
    a = parse_args()
    print(f"📥 Loading dataset: {a.dataset}", flush=True)
    ds = load_dataset("json", data_files=a.dataset, split="train")
    cols = ds.column_names
    print("🧾 Columns:", cols, flush=True)

    tok = AutoTokenizer.from_pretrained(a.model_name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token  # GPT-2-style models ship without a pad token
    model = AutoModelForCausalLM.from_pretrained(a.model_name)

    def build_texts(batch) -> List[str]:
        if "text" in batch:
            return [str(t) for t in batch["text"]]
        if "prompt" in batch and "completion" in batch:
            return [f"{str(p).rstrip()}\n{str(c)}" for p, c in zip(batch["prompt"], batch["completion"])]
        raise ValueError("Dataset must contain 'text' OR both 'prompt' and 'completion'.")
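
    # Accepted record shapes (illustrative rows, not from the commit):
    #   {"text": "one full training example as a single string"}
    #   {"prompt": "Q: What is 2+2?", "completion": "A: 4"}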

    def tokenize(batch):
        texts = build_texts(batch)
        return tok(texts, padding="max_length", truncation=True, max_length=a.block_size)

    print("🔁 Tokenizing…", flush=True)
    tokds = ds.map(tokenize, batched=True, remove_columns=cols)
    collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)
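    # Note: with mlm=False the collator copies input_ids into labels and masks
    # pad positions with -100; since pad_token is set to eos_token above, any
    # real EOS tokens would be excluded from the loss as well (a known caveat).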

    print("⚙ Trainer…", flush=True)
    args = TrainingArguments(
        output_dir=a.output,
        overwrite_output_dir=True,
        per_device_train_batch_size=a.batch_size,
        num_train_epochs=a.epochs,
        learning_rate=a.learning_rate,
        logging_steps=10,
        save_steps=200,
        save_total_limit=1,
        report_to=[],
        fp16=False,
    )
    trainer = Trainer(model=model, args=args, train_dataset=tokds, tokenizer=tok, data_collator=collator)
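    # Note: newer transformers releases deprecate Trainer(tokenizer=...) in
    # favor of processing_class=...; tokenizer= remains accepted on older ones.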

    print("🚀 Training…", flush=True)
    trainer.train()
    print(f"💾 Saving to {a.output}", flush=True)
    os.makedirs(a.output, exist_ok=True)
    trainer.save_model(a.output)
    tok.save_pretrained(a.output)
    print("✅ Done.", flush=True)

if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print(f"❌ Training failed: {e}", flush=True)
        sys.exit(1)
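
A quick way to smoke-test the script is to generate a toy dataset first (a minimal sketch; the file name toy.jsonl and the sample rows are illustrative, not part of this commit):

import json

# Two tiny records in the prompt/completion shape that build_texts() accepts.
rows = [
    {"prompt": "Q: What is 2+2?", "completion": "A: 4"},
    {"prompt": "Q: Capital of France?", "completion": "A: Paris"},
]
with open("toy.jsonl", "w") as f:
    for r in rows:
        f.write(json.dumps(r) + "\n")

Then run: python train.py --dataset toy.jsonl --output trained_model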