Spaces:
Sleeping
Sleeping
Update train.py
Browse files
train.py
CHANGED
|
@@ -1,17 +1,19 @@
|
|
| 1 |
import argparse, os, traceback
|
|
|
|
| 2 |
from datasets import load_dataset
|
| 3 |
from transformers import (
|
| 4 |
AutoTokenizer, AutoModelForCausalLM,
|
| 5 |
DataCollatorForLanguageModeling, Trainer, TrainingArguments
|
| 6 |
)
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
|
|
|
| 10 |
|
| 11 |
def parse_args():
|
| 12 |
ap = argparse.ArgumentParser()
|
| 13 |
ap.add_argument("--dataset", required=True)
|
| 14 |
-
ap.add_argument("--output", default="trained_model")
|
| 15 |
ap.add_argument("--model_name", default="Salesforce/codegen-350M-multi")
|
| 16 |
ap.add_argument("--epochs", type=float, default=1.0)
|
| 17 |
ap.add_argument("--batch_size", type=int, default=2)
|
|
@@ -26,6 +28,7 @@ def main():
|
|
| 26 |
ds = load_dataset("json", data_files=a.dataset, split="train")
|
| 27 |
cols = ds.column_names
|
| 28 |
print("🧾 Columns:", cols, flush=True)
|
|
|
|
| 29 |
if a.subset and a.subset > 0:
|
| 30 |
ds = ds.select(range(min(a.subset, len(ds))))
|
| 31 |
print(f"✂ Subset: {len(ds)} rows", flush=True)
|
|
@@ -71,12 +74,14 @@ def main():
|
|
| 71 |
os.makedirs(a.output, exist_ok=True)
|
| 72 |
trainer.save_model(a.output)
|
| 73 |
tok.save_pretrained(a.output)
|
| 74 |
-
|
| 75 |
print("✅ Done.", flush=True)
|
| 76 |
|
| 77 |
if __name__ == "__main__":
|
| 78 |
try:
|
|
|
|
|
|
|
| 79 |
main()
|
| 80 |
except Exception:
|
| 81 |
-
|
| 82 |
raise
|
|
|
|
| 1 |
import argparse, os, traceback
|
| 2 |
+
from pathlib import Path
|
| 3 |
from datasets import load_dataset
|
| 4 |
from transformers import (
|
| 5 |
AutoTokenizer, AutoModelForCausalLM,
|
| 6 |
DataCollatorForLanguageModeling, Trainer, TrainingArguments
|
| 7 |
)
|
| 8 |
|
| 9 |
+
ROOT = Path(__file__).resolve().parent # /home/user/app
|
| 10 |
+
DONE = ROOT / "TRAIN_DONE" # <- write here
|
| 11 |
+
ERRF = ROOT / "TRAIN_ERROR"
|
| 12 |
|
| 13 |
def parse_args():
|
| 14 |
ap = argparse.ArgumentParser()
|
| 15 |
ap.add_argument("--dataset", required=True)
|
| 16 |
+
ap.add_argument("--output", default=str(ROOT / "trained_model"))
|
| 17 |
ap.add_argument("--model_name", default="Salesforce/codegen-350M-multi")
|
| 18 |
ap.add_argument("--epochs", type=float, default=1.0)
|
| 19 |
ap.add_argument("--batch_size", type=int, default=2)
|
|
|
|
| 28 |
ds = load_dataset("json", data_files=a.dataset, split="train")
|
| 29 |
cols = ds.column_names
|
| 30 |
print("🧾 Columns:", cols, flush=True)
|
| 31 |
+
|
| 32 |
if a.subset and a.subset > 0:
|
| 33 |
ds = ds.select(range(min(a.subset, len(ds))))
|
| 34 |
print(f"✂ Subset: {len(ds)} rows", flush=True)
|
|
|
|
| 74 |
os.makedirs(a.output, exist_ok=True)
|
| 75 |
trainer.save_model(a.output)
|
| 76 |
tok.save_pretrained(a.output)
|
| 77 |
+
DONE.write_text("ok") # <- SIGNAL!
|
| 78 |
print("✅ Done.", flush=True)
|
| 79 |
|
| 80 |
if __name__ == "__main__":
|
| 81 |
try:
|
| 82 |
+
DONE.unlink(missing_ok=True)
|
| 83 |
+
ERRF.unlink(missing_ok=True)
|
| 84 |
main()
|
| 85 |
except Exception:
|
| 86 |
+
ERRF.write_text(traceback.format_exc())
|
| 87 |
raise
|