# Spaces: Sleeping
# File size: 1,427 Bytes
# 078d71d 04a8e34 86cb75d
import argparse
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
import torch
_MODEL_NAME = "distilgpt2"
_OUTPUT_DIR = "./trained_model"
_MAX_LENGTH = 256
# Label value used to exclude a position from the language-modeling loss.
_IGNORE_INDEX = -100


def _mask_padded_labels(input_ids, attention_mask):
    """Copy *input_ids* as labels, replacing padded positions with ``_IGNORE_INDEX``."""
    return [
        [token if keep else _IGNORE_INDEX for token, keep in zip(ids, mask)]
        for ids, mask in zip(input_ids, attention_mask)
    ]


def main(argv=None):
    """Fine-tune distilgpt2 on the ``prompt`` column of a JSON dataset.

    Args:
        argv: Optional list of command-line arguments. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, so existing ``main()``
            callers behave as before.

    Raises:
        SystemExit: If ``--dataset`` is missing or the loaded dataset has no
            ``prompt`` column (reported through argparse).
    """
    parser = argparse.ArgumentParser(
        description="Fine-tune distilgpt2 on prompts stored in a JSON file."
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Path to a JSON/JSONL file whose records contain a 'prompt' field.",
    )
    args = parser.parse_args(argv)

    print("📥 Loading dataset...")
    dataset = load_dataset("json", data_files=args.dataset, split="train")
    if "prompt" not in dataset.column_names:
        # Fail early with a readable message instead of a KeyError inside map().
        parser.error(
            f"dataset {args.dataset!r} has no 'prompt' column "
            f"(found: {dataset.column_names})"
        )

    tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
    # The tokenizer is given the EOS token as its padding token so that
    # padding="max_length" below has a token to pad with.
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        encoded = tokenizer(
            examples["prompt"],
            truncation=True,
            padding="max_length",
            max_length=_MAX_LENGTH,
        )
        # Without a "labels" entry the causal-LM forward pass returns no loss
        # and Trainer.train() cannot run. Padded positions are masked out so
        # they do not contribute to the loss.
        encoded["labels"] = _mask_padded_labels(
            encoded["input_ids"], encoded["attention_mask"]
        )
        return encoded

    tokenized_dataset = dataset.map(tokenize_function, batched=True)

    print("📦 Loading model...")
    model = AutoModelForCausalLM.from_pretrained(_MODEL_NAME)

    training_args = TrainingArguments(
        output_dir=_OUTPUT_DIR,
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        save_strategy="epoch",
        logging_dir="./logs",
        logging_steps=10,
        # NOTE(review): recent transformers releases deprecate `no_cuda` in
        # favour of `use_cpu` — confirm against the pinned version.
        no_cuda=not torch.cuda.is_available(),
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
    )

    print("🚀 Starting training...")
    trainer.train()

    # save_strategy="epoch" only writes checkpoint sub-directories; save the
    # final model and tokenizer to the output directory itself so the message
    # below is accurate.
    trainer.save_model()
    tokenizer.save_pretrained(_OUTPUT_DIR)
    print(f"✅ Training finished. Model saved to {_OUTPUT_DIR}")
# Run the training entry point only when executed as a script.
if __name__ == "__main__":
    main()