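# Fine-tunes distilgpt2 for causal language modeling on a prompt/completion
# dataset. Usage (the script name is assumed for illustration):
#   python train.py --dataset data.jsonl --output trained_model
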
import argparse

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", required=True)
parser.add_argument("--output", default="trained_model")
args = parser.parse_args()

| print("π Loading dataset...") | |
| dataset = load_dataset("json", data_files=args.dataset, split="train") | |
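
# Each record is expected to carry "prompt" and "completion" fields
# (as used in tokenize() below), e.g. {"prompt": "...", "completion": "..."}.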
| print("π§ Loading model and tokenizer...") | |
| tokenizer = AutoTokenizer.from_pretrained("distilgpt2") | |
| tokenizer.pad_token = tokenizer.eos_token | |
| model = AutoModelForCausalLM.from_pretrained("distilgpt2") | |

# Batch-safe tokenization: concatenate each prompt with its completion and
# tokenize to a fixed length; labels are derived later by the data collator.
def tokenize(batch):
    full_texts = [str(p) + str(c) for p, c in zip(batch["prompt"], batch["completion"])]
    return tokenizer(full_texts, padding="max_length", truncation=True, max_length=256)
| print("π Tokenizing...") | |
| tokenized = dataset.map(tokenize, batched=True) | |
| print("π¦ Setting up trainer...") | |
| data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) | |
training_args = TrainingArguments(
    output_dir=args.output,
    per_device_train_batch_size=2,
    num_train_epochs=1,
    logging_steps=1,
    save_steps=5,
    save_total_limit=1,
    report_to=[],
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
| print("π Starting training...") | |
| trainer.train() | |
| trainer.save_model(args.output) | |
| print("β Done.") |