"""Fine-tune a seq2seq checkpoint with Seq2SeqTrainer, save it, and sample text.

Steps: load the checkpoint and its matching tokenizer, train on
``your_training_dataset`` (a placeholder the user must define), save the model
and tokenizer to ``OUTPUT_DIR``, then generate text for a sample prompt.
"""
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

MODEL_NAME = "Kaludi/chatgpt-gpt4-prompts-bart-large-cnn-samsum"
OUTPUT_DIR = "./gpt4-text-gen"

# The tokenizer must come from the same checkpoint as the model. A GPT-2
# tokenizer produces token ids that do not match this model's vocabulary, and
# it has no pad token, so DataCollatorForSeq2Seq could not pad batches with it.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# from_tf=True loads the checkpoint from TensorFlow weights.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, from_tf=True)

# TODO: define `your_training_dataset` here, before it is handed to the trainer.
# Each row should contain the tokenized source text ("input_ids",
# "attention_mask") and "labels" holding the token ids of the target text,
# e.g. built with `tokenizer(source, text_target=target, truncation=True)`.

# Pads each batch dynamically; labels are padded with -100 so padding
# positions are ignored by the loss.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,
    per_device_train_batch_size=4,
    save_steps=10_000,
    save_total_limit=2,  # keep at most two checkpoints on disk
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=your_training_dataset,  # noqa: F821 - user-supplied, see TODO above
)

trainer.train()

model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# After training, the model is still in train mode (dropout active) and may
# live on a GPU. Switch to eval mode and put the inputs on the model's device
# before generating.
model.eval()
input_text = "Hello!"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
# top_k / top_p only take effect when sampling is enabled via do_sample=True.
output = model.generate(
    **inputs,
    max_length=100,
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    do_sample=True,
    top_k=50,
    top_p=0.95,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text: ", generated_text)