import argparse
import json
from pathlib import Path

from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)


def load_sft(path: Path):
    """Read a JSONL file of {"prompt": ..., "response": ...} rows into instruction-formatted text."""
    rows = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            text = (
                "### Instruction\n"
                f"{obj['prompt']}\n\n"
                "### Response\n"
                f"{obj['response']}\n"
            )
            rows.append({"text": text})
    return rows


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sft-data", default="ollama_rl_out/sft_dataset.jsonl")
    parser.add_argument("--model-name", default="distilgpt2")
    parser.add_argument("--output-dir", default="hf_sft_checkpoint")
    parser.add_argument("--max-steps", type=int, default=60)
    args = parser.parse_args()

    rows = load_sft(Path(args.sft_data))
    if not rows:
        raise ValueError(
            f"Empty SFT dataset at {args.sft_data}. "
            "Run rollout + dataset builder first and verify the dataset path."
        )
    print(f"Loaded {len(rows)} SFT examples from {args.sft_data}")

    dataset = Dataset.from_list(rows)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if tokenizer.pad_token is None:
        # GPT-2-style tokenizers ship without a pad token; reuse EOS so padding works.
        tokenizer.pad_token = tokenizer.eos_token

    def tok(batch):
        return tokenizer(
            batch["text"],
            truncation=True,
            max_length=384,
            padding="max_length",
        )

    tokenized = dataset.map(tok, batched=True, remove_columns=["text"])

    model = AutoModelForCausalLM.from_pretrained(args.model_name)
    train_args = TrainingArguments(
        output_dir=args.output_dir,
        max_steps=args.max_steps,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        learning_rate=2e-5,
        logging_strategy="steps",
        logging_steps=10,
        save_strategy="steps",
        save_steps=10,
        save_total_limit=2,
        report_to=[],
        fp16=False,
        bf16=False,
    )

    trainer = Trainer(
        model=model,
        args=train_args,
        train_dataset=tokenized,
        # mlm=False gives the standard causal-LM objective (labels are the shifted inputs).
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print(f"saved_checkpoint={args.output_dir}")


if __name__ == "__main__":
    main()
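
# Example invocation (script filename is illustrative; flags and defaults mirror the
# argparse definitions above):
#   python sft_train.py --sft-data ollama_rl_out/sft_dataset.jsonl \
#       --model-name distilgpt2 --output-dir hf_sft_checkpoint --max-steps 60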