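# Fine-tunes distilgpt2 on a JSON/JSONL dataset whose records contain
# "prompt" and "completion" fields, using the Hugging Face Trainer API.
#
# Example invocation (file and dataset names are illustrative):
#   python train.py --dataset data.jsonl --output trained_model
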
import argparse
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", required=True)
parser.add_argument("--output", default="trained_model")
args = parser.parse_args()

print("πŸ“Š Loading dataset...")
dataset = load_dataset("json", data_files=args.dataset, split="train")

print("🧠 Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
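# GPT-2-family tokenizers ship without a pad token; reuse EOS so batches can be padded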
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# ✅ Clean, batch-safe tokenize
def tokenize(batch):
    full_texts = [str(p) + str(c) for p, c in zip(batch["prompt"], batch["completion"])]
    return tokenizer(full_texts, padding="max_length", truncation=True, max_length=256)
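
# A common refinement: append tokenizer.eos_token to each completion so the
# fine-tuned model learns where to stop generating.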

print("πŸ” Tokenizing...")
# Drop the raw text columns so only tokenized model inputs remain
tokenized = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)

print("πŸ“¦ Setting up trainer...")
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir=args.output,
    per_device_train_batch_size=2,
    num_train_epochs=1,
    logging_steps=1,
    save_steps=5,
    save_total_limit=1,
    report_to=[]
)
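
# Note: recent transformers releases rename Trainer's `tokenizer` argument to
# `processing_class`; passing `tokenizer=` still works but may emit a deprecation warning.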

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("πŸš€ Starting training...")
trainer.train()
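# save_model writes the weights and config; the tokenizer is saved alongside
# because it was passed to the Trainer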
trainer.save_model(args.output)
print("βœ… Done.")