import argparse
import os
import sys
from typing import List

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
)

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", required=True, help="Path to a JSON/JSONL (or .gz) file.")
    p.add_argument("--output", default="trained_model", help="Folder to save the fine-tuned model.")
    p.add_argument("--model_name", default="distilgpt2", help="Base model.")
    p.add_argument("--epochs", type=float, default=1.0)
    p.add_argument("--batch_size", type=int, default=2)
    p.add_argument("--block_size", type=int, default=256)
    p.add_argument("--learning_rate", type=float, default=5e-5)
    return p.parse_args()
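
# Example invocation (illustrative; the filename train_lm.py is an assumption, not part of this script):
#   python train_lm.py --dataset data.jsonl --output trained_model --epochs 1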

def main():
    args = parse_args()
    print(f"📥 Loading dataset: {args.dataset}", flush=True)
    ds = load_dataset("json", data_files=args.dataset, split="train")

    cols = ds.column_names
    print(f"🧾 Columns: {cols}", flush=True)

    print(f"🧠 Loading model & tokenizer: {args.model_name}", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # GPT-2 family has no pad token

    model = AutoModelForCausalLM.from_pretrained(args.model_name)
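
    # Example records this script accepts (illustrative values only):
    #   {"text": "Once upon a time ..."}
    #   {"prompt": "Q: What is the capital of France?", "completion": "Paris."}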

    def build_texts(batch) -> List[str]:
        if "text" in batch:
            return [str(t) for t in batch["text"]]
        if "prompt" in batch and "completion" in batch:
            # simple join: prompt + newline + completion
            return [f"{str(p).rstrip()}\n{str(c)}" for p, c in zip(batch["prompt"], batch["completion"])]
        raise ValueError("Dataset must contain 'text' OR both 'prompt' and 'completion' fields.")

    def tokenize(batch):
        texts = build_texts(batch)
        return tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=args.block_size,
        )

    print("🔁 Tokenizing…", flush=True)
    tokenized = ds.map(
        tokenize,
        batched=True,
        remove_columns=cols,  # keep only tokenized fields
    )

    # With mlm=False the collator builds causal-LM labels by copying input_ids
    # and masking padding positions to -100 so they are ignored by the loss.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    print("⚙ Preparing Trainer…", flush=True)
    training_args = TrainingArguments(
        output_dir=args.output,
        overwrite_output_dir=True,
        per_device_train_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        learning_rate=args.learning_rate,
        logging_steps=10,
        save_steps=200,           # checkpoint every 200 steps; save_total_limit keeps only the newest
        save_total_limit=1,
        report_to=[],
        gradient_accumulation_steps=1,
        fp16=False,               # CPU-friendly default; set True when training on a GPU with fp16 support
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    print("🚀 Training…", flush=True)
    trainer.train()

    print(f"💾 Saving to: {args.output}", flush=True)
    os.makedirs(args.output, exist_ok=True)
    trainer.save_model(args.output)
    tokenizer.save_pretrained(args.output)
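    # The saved folder can be reloaded later for generation, e.g. (illustrative):
    #   model = AutoModelForCausalLM.from_pretrained(args.output)
    #   tokenizer = AutoTokenizer.from_pretrained(args.output)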
    print("✅ Done.", flush=True)

if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Make sure a failure returns non-zero so your app can detect it
        print(f"❌ Training failed: {e}", flush=True)
        sys.exit(1)