Percy3822 commited on
Commit
f5d9b2e
·
verified ·
1 Parent(s): 2c11c52

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +24 -8
train.py CHANGED
@@ -8,22 +8,37 @@ from transformers import (
8
 
9
  def parse_args():
10
  p = argparse.ArgumentParser()
11
- p.add_argument("--dataset", required=True, help="JSONL (.jsonl or .jsonl.gz)")
12
  p.add_argument("--output", default="trained_model")
13
- p.add_argument("--model_name", default="distilgpt2") # tiny & quick
14
- p.add_argument("--epochs", type=float, default=0.5) # short run
15
  p.add_argument("--batch_size", type=int, default=2)
16
  p.add_argument("--block_size", type=int, default=256)
17
  p.add_argument("--learning_rate", type=float, default=5e-5)
 
 
 
 
18
  return p.parse_args()
19
 
20
  def main():
21
  a = parse_args()
 
 
 
 
 
 
 
22
  print(f"📥 Loading dataset: {a.dataset}", flush=True)
23
  ds = load_dataset("json", data_files=a.dataset, split="train")
24
  cols = ds.column_names
25
  print("🧾 Columns:", cols, flush=True)
26
 
 
 
 
 
27
  tok = AutoTokenizer.from_pretrained(a.model_name)
28
  if tok.pad_token is None:
29
  tok.pad_token = tok.eos_token
@@ -33,7 +48,7 @@ def main():
33
  if "text" in batch:
34
  return [str(t) for t in batch["text"]]
35
  if "prompt" in batch and "completion" in batch:
36
- return [f"{str(p).rstrip()}\n{str(c)}" for p,c in zip(batch["prompt"], batch["completion"])]
37
  raise ValueError("Dataset must contain 'text' OR both 'prompt' and 'completion'.")
38
 
39
  def tokenize(batch):
@@ -49,13 +64,14 @@ def main():
49
  output_dir=a.output,
50
  overwrite_output_dir=True,
51
  per_device_train_batch_size=a.batch_size,
52
- num_train_epochs=a.epochs,
53
  learning_rate=a.learning_rate,
54
- logging_steps=10,
55
- save_steps=200,
56
  save_total_limit=1,
57
  report_to=[],
58
  fp16=False,
 
59
  )
60
  trainer = Trainer(model=model, args=args, train_dataset=tokds, tokenizer=tok, data_collator=collator)
61
 
@@ -67,7 +83,7 @@ def main():
67
  tok.save_pretrained(a.output)
68
  print("✅ Done.", flush=True)
69
 
70
- if __name__ == "__main__":
71
  try:
72
  main()
73
  except Exception as e:
 
8
 
9
  def parse_args():
10
  p = argparse.ArgumentParser()
11
+ p.add_argument("--dataset", required=True, help="JSON/JSONL (.jsonl or .jsonl.gz)")
12
  p.add_argument("--output", default="trained_model")
13
+ p.add_argument("--model_name", default="distilgpt2")
14
+ p.add_argument("--epochs", type=float, default=0.5)
15
  p.add_argument("--batch_size", type=int, default=2)
16
  p.add_argument("--block_size", type=int, default=256)
17
  p.add_argument("--learning_rate", type=float, default=5e-5)
18
+ # quick mode:
19
+ p.add_argument("--quick", type=int, default=0) # 1 => tiny model + fast
20
+ p.add_argument("--max_steps", type=int, default=0) # >0 overrides epochs
21
+ p.add_argument("--subset", type=int, default=0) # use first N rows
22
  return p.parse_args()
23
 
24
  def main():
25
  a = parse_args()
26
+
27
+ if a.quick:
28
+ a.model_name = "sshleifer/tiny-gpt2" # ultra-tiny, very fast
29
+ if a.max_steps <= 0: a.max_steps = 8
30
+ if a.subset <= 0: a.subset = 32
31
+ a.epochs = 1.0
32
+
33
  print(f"📥 Loading dataset: {a.dataset}", flush=True)
34
  ds = load_dataset("json", data_files=a.dataset, split="train")
35
  cols = ds.column_names
36
  print("🧾 Columns:", cols, flush=True)
37
 
38
+ if a.subset and a.subset > 0:
39
+ ds = ds.select(range(min(a.subset, len(ds))))
40
+ print(f"✂ Using subset: {len(ds)} rows", flush=True)
41
+
42
  tok = AutoTokenizer.from_pretrained(a.model_name)
43
  if tok.pad_token is None:
44
  tok.pad_token = tok.eos_token
 
48
  if "text" in batch:
49
  return [str(t) for t in batch["text"]]
50
  if "prompt" in batch and "completion" in batch:
51
+ return [f"{str(p).rstrip()}\n{str(c)}" for p, c in zip(batch["prompt"], batch["completion"])]
52
  raise ValueError("Dataset must contain 'text' OR both 'prompt' and 'completion'.")
53
 
54
  def tokenize(batch):
 
64
  output_dir=a.output,
65
  overwrite_output_dir=True,
66
  per_device_train_batch_size=a.batch_size,
67
+ num_train_epochs=a.epochs if a.max_steps == 0 else 1,
68
  learning_rate=a.learning_rate,
69
+ logging_steps=1,
70
+ save_steps=50,
71
  save_total_limit=1,
72
  report_to=[],
73
  fp16=False,
74
+ max_steps=a.max_steps if a.max_steps > 0 else -1,
75
  )
76
  trainer = Trainer(model=model, args=args, train_dataset=tokds, tokenizer=tok, data_collator=collator)
77
 
 
83
  tok.save_pretrained(a.output)
84
  print("✅ Done.", flush=True)
85
 
86
+ if _name_ == "_main_":
87
  try:
88
  main()
89
  except Exception as e: