metadata
tags:
  - Luc bat poem
  - GPT2
  - Causal-lm
model-index:
  - name: PoetGPT_Vietnamese
    results: []
language:
  - vi
widget:
  - text: <|startoftext|>ngày xuân
    example_title: ngày xuân

Lục Bát AI Poet

How to generate with a prompt

Prepend the <|startoftext|> token to your prompt:
<|startoftext|> + your_prompt
Example ("ngày xuân" means "spring day"):
<|startoftext|>ngày xuân
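
The prompt format above can be sketched as a small helper. This is a minimal illustration, not part of the model's API: the name `build_prompt` is hypothetical, and the lowercasing step is assumed from the `prompt.lower()` call in the usage snippet.

```python
START_TOKEN = "<|startoftext|>"

def build_prompt(text: str) -> str:
    """Prefix the start token and lowercase the user text.

    Hypothetical helper: lowercasing mirrors the prompt.lower() call
    in the usage snippet; stripping whitespace is an added assumption.
    """
    return START_TOKEN + text.strip().lower()

print(build_prompt("Ngày Xuân"))
```

The returned string can be passed to `tokenizer.encode` in place of the hand-built `prompt` variable.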

Usage:

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load the tokenizer and model weights from the Hugging Face Hub.
tokenizer = GPT2Tokenizer.from_pretrained('luanngo/PoetGPT')
model = GPT2LMHeadModel.from_pretrained('luanngo/PoetGPT')

# The prompt is the start token followed by lowercased text.
prompt = "<|startoftext|>" + "ngày xuân"
input_ids = tokenizer.encode(prompt.lower(), return_tensors='pt')
max_length = 100  # total length in tokens, prompt included

# Sample with 5 beams and nucleus sampling; return 3 candidate poems.
outputs = model.generate(input_ids,
                          do_sample=True,
                          max_length=max_length,
                          top_p=0.95,
                          temperature=1.0,
                          repetition_penalty=10.0,
                          num_beams=5,
                          early_stopping=True,
                          num_return_sequences=3)

# Decode each returned sequence back to text and print it.
for i, output in enumerate(outputs):
    print(f">> Generated text {i + 1}\n\n{tokenizer.decode(output.tolist())}")
    print('\n---')

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 3e-4
  • train_batch_size: 64
  • eval_batch_size: 64
  • seed: 37
  • distributed_type: multi-GPU
  • num_devices: 2
  • total_train_batch_size: 128
  • total_eval_batch_size: 128
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • num_epochs: 30
  • mixed_precision_training: Native AMP
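
As a rough sketch of how these settings relate, the snippet below derives the total batch size and a linear learning-rate decay. It assumes zero warmup steps (none are listed) and 704 optimizer steps per epoch (read off the training results); the function `linear_lr` is a hypothetical stand-in for the scheduler, not the training code itself.

```python
# Values taken from the hyperparameter list.
TRAIN_BATCH_SIZE = 64
NUM_DEVICES = 2
BASE_LR = 3e-4
NUM_EPOCHS = 30

# Assumption: 704 optimizer steps per epoch, as in the results table.
STEPS_PER_EPOCH = 704
# Assumption: no warmup, since the list does not mention any.
WARMUP_STEPS = 0

total_train_batch_size = TRAIN_BATCH_SIZE * NUM_DEVICES
total_steps = STEPS_PER_EPOCH * NUM_EPOCHS

def linear_lr(step: int) -> float:
    """Linear warmup (if any), then linear decay to zero at total_steps."""
    if step < WARMUP_STEPS:
        return BASE_LR * step / max(1, WARMUP_STEPS)
    remaining = total_steps - step
    return BASE_LR * max(0.0, remaining / max(1, total_steps - WARMUP_STEPS))

print(total_train_batch_size, total_steps)
for step in (0, total_steps // 2, total_steps):
    print(step, linear_lr(step))
```

Under these assumptions the learning rate starts at 3e-4 and reaches zero only at the end of the 30 configured epochs.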

Training results

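| Training Loss | Epoch | Step | Validation Loss |
|---------------|-------|------|-----------------|
| 2.6633        | 1.0   | 704  | 2.6718          |
| 2.5567        | 2.0   | 1408 | 2.5756          |
| 2.4885        | 3.0   | 2112 | 2.5283          |
| 2.3552        | 4.0   | 2816 | 2.4395          |
| 2.3084        | 5.0   | 3520 | 2.3811          |
| 2.2587        | 6.0   | 4224 | 2.3699          |
| 2.1938        | 7.0   | 4928 | 2.3470          |
| 2.1491        | 8.0   | 5632 | 2.3225          |
| 2.0623        | 9.0   | 6336 | 2.3276          |
| 2.0672        | 10.0  | 7040 | 2.3301          |
| 2.0293        | 11.0  | 7744 | 2.3186          |
| 1.9694        | 12.0  | 8448 | 2.3331          |
| 1.8658        | 13.0  | 9152 | 2.3565          |
| 1.8558        | 14.0  | 9856 | 2.3592          |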
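
One way to read these numbers is to convert the validation loss to perplexity and locate the epoch where it is lowest. The sketch below assumes the reported loss is the mean per-token cross-entropy in nats, which the card does not state explicitly; the names `RESULTS` and `perplexity` are illustrative.

```python
import math

# (epoch, training loss, validation loss), copied from the results above.
RESULTS = [
    (1, 2.6633, 2.6718), (2, 2.5567, 2.5756), (3, 2.4885, 2.5283),
    (4, 2.3552, 2.4395), (5, 2.3084, 2.3811), (6, 2.2587, 2.3699),
    (7, 2.1938, 2.3470), (8, 2.1491, 2.3225), (9, 2.0623, 2.3276),
    (10, 2.0672, 2.3301), (11, 2.0293, 2.3186), (12, 1.9694, 2.3331),
    (13, 1.8658, 2.3565), (14, 1.8558, 2.3592),
]

def perplexity(loss: float) -> float:
    # Assumes `loss` is mean per-token cross-entropy in nats.
    return math.exp(loss)

# Pick the row with the lowest validation loss.
best_epoch, _, best_val_loss = min(RESULTS, key=lambda row: row[2])
best_ppl = perplexity(best_val_loss)
print(f"best epoch: {best_epoch}, val loss: {best_val_loss}, "
      f"perplexity: {best_ppl:.2f}")
```

Validation loss is lowest at epoch 11 (2.3186) and rises afterwards while training loss keeps falling, which may indicate the onset of overfitting.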

Framework versions

  • Transformers 4.22.2
  • PyTorch 1.13.0+cu117
  • Datasets 2.6.1
  • Tokenizers 0.12.1