Percy3822 committed on
Commit
7e32f1f
·
verified ·
1 Parent(s): 12e3c33

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +10 -5
train.py CHANGED
@@ -1,17 +1,19 @@
1
  import argparse, os, traceback
 
2
  from datasets import load_dataset
3
  from transformers import (
4
  AutoTokenizer, AutoModelForCausalLM,
5
  DataCollatorForLanguageModeling, Trainer, TrainingArguments
6
  )
7
 
8
- DONE = "TRAIN_DONE"
9
- ERRF = "TRAIN_ERROR"
 
10
 
11
  def parse_args():
12
  ap = argparse.ArgumentParser()
13
  ap.add_argument("--dataset", required=True)
14
- ap.add_argument("--output", default="trained_model")
15
  ap.add_argument("--model_name", default="Salesforce/codegen-350M-multi")
16
  ap.add_argument("--epochs", type=float, default=1.0)
17
  ap.add_argument("--batch_size", type=int, default=2)
@@ -26,6 +28,7 @@ def main():
26
  ds = load_dataset("json", data_files=a.dataset, split="train")
27
  cols = ds.column_names
28
  print("🧾 Columns:", cols, flush=True)
 
29
  if a.subset and a.subset > 0:
30
  ds = ds.select(range(min(a.subset, len(ds))))
31
  print(f"✂ Subset: {len(ds)} rows", flush=True)
@@ -71,12 +74,14 @@ def main():
71
  os.makedirs(a.output, exist_ok=True)
72
  trainer.save_model(a.output)
73
  tok.save_pretrained(a.output)
74
- open(DONE, "w").write("ok") # <—— signal file
75
  print("✅ Done.", flush=True)
76
 
77
  if __name__ == "__main__":
78
  try:
 
 
79
  main()
80
  except Exception:
81
- open(ERRF, "w").write(traceback.format_exc())
82
  raise
 
1
  import argparse, os, traceback
2
+ from pathlib import Path
3
  from datasets import load_dataset
4
  from transformers import (
5
  AutoTokenizer, AutoModelForCausalLM,
6
  DataCollatorForLanguageModeling, Trainer, TrainingArguments
7
  )
8
 
9
+ ROOT = Path(__file__).resolve().parent # /home/user/app
10
+ DONE = ROOT / "TRAIN_DONE" # <- write here
11
+ ERRF = ROOT / "TRAIN_ERROR"
12
 
13
  def parse_args():
14
  ap = argparse.ArgumentParser()
15
  ap.add_argument("--dataset", required=True)
16
+ ap.add_argument("--output", default=str(ROOT / "trained_model"))
17
  ap.add_argument("--model_name", default="Salesforce/codegen-350M-multi")
18
  ap.add_argument("--epochs", type=float, default=1.0)
19
  ap.add_argument("--batch_size", type=int, default=2)
 
28
  ds = load_dataset("json", data_files=a.dataset, split="train")
29
  cols = ds.column_names
30
  print("🧾 Columns:", cols, flush=True)
31
+
32
  if a.subset and a.subset > 0:
33
  ds = ds.select(range(min(a.subset, len(ds))))
34
  print(f"✂ Subset: {len(ds)} rows", flush=True)
 
74
  os.makedirs(a.output, exist_ok=True)
75
  trainer.save_model(a.output)
76
  tok.save_pretrained(a.output)
77
+ DONE.write_text("ok") # <- SIGNAL!
78
  print("✅ Done.", flush=True)
79
 
80
  if __name__ == "__main__":
81
  try:
82
+ DONE.unlink(missing_ok=True)
83
+ ERRF.unlink(missing_ok=True)
84
  main()
85
  except Exception:
86
+ ERRF.write_text(traceback.format_exc())
87
  raise