import argparse
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
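# Fine-tune distilgpt2 on a JSON/JSONL dataset of prompt/completion pairs supplied via the CLI.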
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", required=True)
parser.add_argument("--output", default="trained_model")
args = parser.parse_args()
print("π Loading dataset...")
dataset = load_dataset("json", data_files=args.dataset, split="train")
print("π§ Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
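# distilgpt2 ships without a pad token; reuse the EOS token so padded batches work.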
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
# Clean, batch-safe tokenize
def tokenize(batch):
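    # Concatenate each prompt with its completion so the model is trained on the full sequence.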
    full_texts = [str(p) + str(c) for p, c in zip(batch["prompt"], batch["completion"])]
    return tokenizer(full_texts, padding="max_length", truncation=True, max_length=256)
print("π Tokenizing...")
tokenized = dataset.map(tokenize, batched=True)
print("π¦ Setting up trainer...")
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
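# Small demo settings: batch size 2, a single epoch, and only the most recent checkpoint kept.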
training_args = TrainingArguments(
    output_dir=args.output,
    per_device_train_batch_size=2,
    num_train_epochs=1,
    logging_steps=1,
    save_steps=5,
    save_total_limit=1,
    report_to=[]
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
print("π Starting training...")
trainer.train()
trainer.save_model(args.output)
print("β
Done.") |