Spaces:
Sleeping
Sleeping
Update train.py
Browse files
train.py
CHANGED
|
@@ -7,13 +7,11 @@ from transformers import (
|
|
| 7 |
)
|
| 8 |
import zipfile
|
| 9 |
|
| 10 |
-
ROOT = Path(__file__).resolve().parent
|
| 11 |
-
|
| 12 |
def parse_args():
|
| 13 |
ap = argparse.ArgumentParser()
|
| 14 |
-
ap.add_argument("--dataset", required=True, help="Path to .jsonl
|
| 15 |
-
ap.add_argument("--output",
|
| 16 |
-
ap.add_argument("--zip_path",
|
| 17 |
ap.add_argument("--model_name", default="Salesforce/codegen-350M-multi")
|
| 18 |
ap.add_argument("--epochs", type=float, default=1.0)
|
| 19 |
ap.add_argument("--batch_size", type=int, default=2)
|
|
@@ -25,6 +23,7 @@ def main():
|
|
| 25 |
a = parse_args()
|
| 26 |
out_dir = Path(a.output).resolve()
|
| 27 |
zip_path = Path(a.zip_path).resolve()
|
|
|
|
| 28 |
|
| 29 |
print(f"📦 Loading dataset from: {a.dataset}", flush=True)
|
| 30 |
ds = load_dataset("json", data_files=a.dataset, split="train")
|
|
@@ -36,13 +35,15 @@ def main():
|
|
| 36 |
tok.pad_token = tok.eos_token
|
| 37 |
model = AutoModelForCausalLM.from_pretrained(a.model_name)
|
| 38 |
|
| 39 |
-
def
|
| 40 |
if "text" in batch:
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
| 46 |
return tok(texts, padding="max_length", truncation=True, max_length=a.block_size)
|
| 47 |
|
| 48 |
print("🔁 Tokenizing…", flush=True)
|
|
|
|
| 7 |
)
|
| 8 |
import zipfile
|
| 9 |
|
|
|
|
|
|
|
| 10 |
def parse_args():
|
| 11 |
ap = argparse.ArgumentParser()
|
| 12 |
+
ap.add_argument("--dataset", required=True, help="Path to .jsonl")
|
| 13 |
+
ap.add_argument("--output", required=True, help="Output model folder")
|
| 14 |
+
ap.add_argument("--zip_path", required=True, help="Path to write .zip")
|
| 15 |
ap.add_argument("--model_name", default="Salesforce/codegen-350M-multi")
|
| 16 |
ap.add_argument("--epochs", type=float, default=1.0)
|
| 17 |
ap.add_argument("--batch_size", type=int, default=2)
|
|
|
|
| 23 |
a = parse_args()
|
| 24 |
out_dir = Path(a.output).resolve()
|
| 25 |
zip_path = Path(a.zip_path).resolve()
|
| 26 |
+
out_dir.parent.mkdir(parents=True, exist_ok=True)
|
| 27 |
|
| 28 |
print(f"📦 Loading dataset from: {a.dataset}", flush=True)
|
| 29 |
ds = load_dataset("json", data_files=a.dataset, split="train")
|
|
|
|
| 35 |
tok.pad_token = tok.eos_token
|
| 36 |
model = AutoModelForCausalLM.from_pretrained(a.model_name)
|
| 37 |
|
| 38 |
+
def to_text(batch):
|
| 39 |
if "text" in batch:
|
| 40 |
+
return batch["text"]
|
| 41 |
+
if "prompt" in batch and "completion" in batch:
|
| 42 |
+
return [str(p).rstrip() + "\n" + str(c) for p, c in zip(batch["prompt"], batch["completion"])]
|
| 43 |
+
raise ValueError("Dataset must have 'text' or 'prompt' + 'completion'.")
|
| 44 |
+
|
| 45 |
+
def tokenize(batch):
|
| 46 |
+
texts = to_text(batch)
|
| 47 |
return tok(texts, padding="max_length", truncation=True, max_length=a.block_size)
|
| 48 |
|
| 49 |
print("🔁 Tokenizing…", flush=True)
|