"""Fine-tune distilgpt2 on a JSON-lines dataset of prompts.

Usage:
    python train.py --dataset path/to/data.json

The dataset file must contain a "prompt" field per record. The trained
model and tokenizer are written to ./trained_model.
"""

import argparse

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)


def main():
    """Parse CLI args, tokenize the dataset, and run one training epoch."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True)
    args = parser.parse_args()

    print("📥 Loading dataset...")
    dataset = load_dataset("json", data_files=args.dataset, split="train")

    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    # GPT-2 family has no pad token; reuse EOS so padding="max_length" works.
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        return tokenizer(
            examples["prompt"],
            truncation=True,
            padding="max_length",
            max_length=256,
        )

    tokenized_dataset = dataset.map(tokenize_function, batched=True)

    print("📦 Loading model...")
    model = AutoModelForCausalLM.from_pretrained("distilgpt2")

    # BUG FIX: the tokenized dataset has no "labels" column, so Trainer had
    # nothing to compute a loss from. For causal LM training, this collator
    # copies input_ids into labels (mlm=False) and masks pad positions.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir="./trained_model",
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        save_strategy="epoch",
        logging_dir="./logs",
        logging_steps=10,
        # NOTE: the former `no_cuda=not torch.cuda.is_available()` was a
        # no-op (Trainer auto-detects CUDA) and the flag is deprecated.
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )

    print("🚀 Starting training...")
    trainer.train()

    # Save the final model + tokenizer explicitly so the message below is
    # accurate (save_strategy="epoch" only writes checkpoint subdirectories).
    trainer.save_model("./trained_model")
    tokenizer.save_pretrained("./trained_model")
    print("✅ Training finished. Model saved to ./trained_model")


if __name__ == "__main__":
    main()