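"""Supervised fine-tuning (SFT) of a small causal LM with the Hugging Face Trainer.

Reads a JSONL file of {"prompt": ..., "response": ...} rows, renders each row
into an "### Instruction / ### Response" template, and runs a short Trainer
loop over the result.

Example invocation (script name illustrative):
    python sft_train.py --sft-data ollama_rl_out/sft_dataset.jsonl --max-steps 60
"""
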
import argparse
import json
from pathlib import Path

from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)


def load_sft(path: Path):
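    """Load {"prompt", "response"} JSONL rows and render them as SFT text."""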
    rows = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            text = (
                "### Instruction\n"
                f"{obj['prompt']}\n\n"
                "### Response\n"
                f"{obj['response']}\n"
            )
            rows.append({"text": text})
    return rows


def main():
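    """Parse CLI args, build the formatted dataset, and run a short SFT pass."""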
    parser = argparse.ArgumentParser()
    parser.add_argument("--sft-data", default="ollama_rl_out/sft_dataset.jsonl")
    parser.add_argument("--model-name", default="distilgpt2")
    parser.add_argument("--output-dir", default="hf_sft_checkpoint")
    parser.add_argument("--max-steps", type=int, default=60)
    args = parser.parse_args()

    rows = load_sft(Path(args.sft_data))
    if not rows:
        raise ValueError(
            f"Empty SFT dataset at {args.sft_data}. Run rollout + dataset builder first and verify the dataset path."
        )
    print(f"Loaded {len(rows)} SFT examples from {args.sft_data}")
    dataset = Dataset.from_list(rows)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
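    # GPT-2-family tokenizers ship without a pad token; reuse EOS so batches can pad.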
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tok(batch):
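        # Pad to a fixed length so every example has the same shape;
        # truncation guards against overlong prompt/response pairs.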
        return tokenizer(
            batch["text"],
            truncation=True,
            max_length=384,
            padding="max_length",
        )

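    # Tokenize the whole dataset up front and drop the raw text column.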
    tokenized = dataset.map(tok, batched=True, remove_columns=["text"])
    model = AutoModelForCausalLM.from_pretrained(args.model_name)

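    # Deliberately small, CPU-friendly settings: tiny batch, no mixed
    # precision, and a hard step cap instead of full epochs.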
    train_args = TrainingArguments(
        output_dir=args.output_dir,
        max_steps=args.max_steps,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        learning_rate=2e-5,
        logging_strategy="steps",
        logging_steps=10,
        save_strategy="steps",
        save_steps=10,
        save_total_limit=2,
        report_to=[],
        fp16=False,
        bf16=False,
    )

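    # mlm=False makes the collator copy input_ids into labels and mask pad
    # tokens with -100; since pad == eos here, EOS positions are masked too.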
    trainer = Trainer(
        model=model,
        args=train_args,
        train_dataset=tokenized,
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print(f"saved_checkpoint={args.output_dir}")


if __name__ == "__main__":
    main()