Percy3822 committed on
Commit
70ef65d
·
verified ·
1 Parent(s): a88d3e2

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +11 -15
train.py CHANGED
@@ -30,14 +30,6 @@ def main():
30
  cols = ds.column_names
31
  print("🧾 Columns:", cols, flush=True)
32
 
33
- # Accept either {"text": "..."} or {"prompt": "...", "completion": "..."}
34
- def to_text(example):
35
- if "text" in example:
36
- return example["text"]
37
- if "prompt" in example and "completion" in example:
38
- return (str(example["prompt"]).rstrip() + "\n" + str(example["completion"]))
39
- raise ValueError("Dataset must have 'text' or 'prompt' + 'completion'.")
40
-
41
  if a.subset and a.subset > 0:
42
  ds = ds.select(range(min(a.subset, len(ds))))
43
  print(f"✂ Subset: {len(ds)} rows", flush=True)
@@ -48,8 +40,16 @@ def main():
48
  tok.pad_token = tok.eos_token
49
  model = AutoModelForCausalLM.from_pretrained(a.model_name)
50
 
 
51
  def tokenize(batch):
52
- texts = [to_text(x) for x in batch]
 
 
 
 
 
 
 
53
  return tok(texts, padding="max_length", truncation=True, max_length=a.block_size)
54
 
55
  print("🔁 Tokenizing…", flush=True)
@@ -66,7 +66,7 @@ def main():
66
  save_steps=200,
67
  save_total_limit=1,
68
  report_to=[],
69
- fp16=False, # CPU-friendly in Spaces
70
  )
71
 
72
  print("⚙ Trainer…", flush=True)
@@ -83,8 +83,4 @@ def main():
83
  print("✅ Done.", flush=True)
84
 
85
  if __name__ == "__main__":
86
- try:
87
- main()
88
- except Exception as e:
89
- print(f"❌ Error during training: {e}", flush=True)
90
- raise
 
30
  cols = ds.column_names
31
  print("🧾 Columns:", cols, flush=True)
32
 
 
 
 
 
 
 
 
 
33
  if a.subset and a.subset > 0:
34
  ds = ds.select(range(min(a.subset, len(ds))))
35
  print(f"✂ Subset: {len(ds)} rows", flush=True)
 
40
  tok.pad_token = tok.eos_token
41
  model = AutoModelForCausalLM.from_pretrained(a.model_name)
42
 
43
+ # ✅ batched=True passes dict-of-lists
44
  def tokenize(batch):
45
+ if "text" in batch:
46
+ texts = batch["text"]
47
+ elif "prompt" in batch and "completion" in batch:
48
+ prompts = batch["prompt"]
49
+ completions = batch["completion"]
50
+ texts = [(str(p).rstrip() + "\n" + str(c)) for p, c in zip(prompts, completions)]
51
+ else:
52
+ raise ValueError("Dataset must have 'text' or 'prompt' + 'completion'.")
53
  return tok(texts, padding="max_length", truncation=True, max_length=a.block_size)
54
 
55
  print("🔁 Tokenizing…", flush=True)
 
66
  save_steps=200,
67
  save_total_limit=1,
68
  report_to=[],
69
+ fp16=False,
70
  )
71
 
72
  print("⚙ Trainer…", flush=True)
 
83
  print("✅ Done.", flush=True)
84
 
85
  if __name__ == "__main__":
86
+ main()