Update train.py

train.py CHANGED
@@ -30,14 +30,6 @@ def main():
     cols = ds.column_names
     print("🧾 Columns:", cols, flush=True)
 
-    # Accept either {"text": "..."} or {"prompt": "...", "completion": "..."}
-    def to_text(example):
-        if "text" in example:
-            return example["text"]
-        if "prompt" in example and "completion" in example:
-            return (str(example["prompt"]).rstrip() + "\n" + str(example["completion"]))
-        raise ValueError("Dataset must have 'text' or 'prompt' + 'completion'.")
-
     if a.subset and a.subset > 0:
         ds = ds.select(range(min(a.subset, len(ds))))
         print(f"✂ Subset: {len(ds)} rows", flush=True)
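Note: the per-example `to_text` helper removed above appears to be the bug this commit fixes. The ✅ comment added in the next hunk says why: `Dataset.map(..., batched=True)` hands the mapping function a dict of lists, not a single example, so per-example `str()` calls end up operating on whole columns. A minimal standalone sketch of that failure mode, with made-up values for illustration:

# With batched=True the mapper receives a dict of lists; calling str() on a
# column stringifies the entire list instead of one row's value.
batch = {"prompt": ["Hi", "Yo"], "completion": ["there", "friend"]}
garbled = str(batch["prompt"]).rstrip() + "\n" + str(batch["completion"])
print(garbled)  # ['Hi', 'Yo']\n['there', 'friend'] ... one mangled string for the whole batch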
@@ -48,8 +40,16 @@ def main():
     tok.pad_token = tok.eos_token
     model = AutoModelForCausalLM.from_pretrained(a.model_name)
 
+    # ✅ batched=True passes dict-of-lists
     def tokenize(batch):
-        texts = …
+        if "text" in batch:
+            texts = batch["text"]
+        elif "prompt" in batch and "completion" in batch:
+            prompts = batch["prompt"]
+            completions = batch["completion"]
+            texts = [(str(p).rstrip() + "\n" + str(c)) for p, c in zip(prompts, completions)]
+        else:
+            raise ValueError("Dataset must have 'text' or 'prompt' + 'completion'.")
         return tok(texts, padding="max_length", truncation=True, max_length=a.block_size)
 
     print("🔁 Tokenizing…", flush=True)
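The inlined replacement walks the two columns in parallel, matching the dict-of-lists shape that `batched=True` actually delivers. A self-contained sketch of the same pattern; the dataset values are invented, and a character count stands in for `tok(...)` so it runs without downloading a tokenizer:

from datasets import Dataset

ds = Dataset.from_dict({"prompt": ["Hi", "Yo"], "completion": ["there", "friend"]})

def tokenize(batch):
    # batch is {"prompt": [...], "completion": [...]}, a dict of lists
    texts = [str(p).rstrip() + "\n" + str(c) for p, c in zip(batch["prompt"], batch["completion"])]
    return {"n_chars": [len(t) for t in texts]}  # stand-in for tok(texts, ...)

ds = ds.map(tokenize, batched=True)
print(ds["n_chars"])  # [8, 9]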
@@ -66,7 +66,7 @@ def main():
         save_steps=200,
         save_total_limit=1,
         report_to=[],
-        fp16=False,
+        fp16=False,
     )
 
     print("⚙ Trainer…", flush=True)
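Keeping `fp16=False` is the safe setting for a CPU-only Space, since transformers typically rejects fp16 mixed precision when no CUDA device is available. A hypothetical guard, not part of this commit, that enables it only where it can work:

import torch

# Hypothetical variant (not in this commit): turn on mixed precision only
# when a CUDA device is actually present.
use_fp16 = torch.cuda.is_available()
# ...then pass fp16=use_fp16 into TrainingArguments(...)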
@@ -83,8 +83,4 @@ def main():
     print("✅ Done.", flush=True)
 
 if __name__ == "__main__":
-    try:
-        main()
-    except Exception as e:
-        print(f"❌ Error during training: {e}", flush=True)
-        raise
+    main()