import itertools
import random

import pandas as pd
import transformers
from datasets import ClassLabel
|
|
|
|
def tokenize_function(examples):
    """Tokenize a batch's raw 'text' column, including special tokens."""
    texts = examples['text']
    return tokenizer(texts, add_special_tokens=True)
|
|
|
|
def group_texts(examples, chunk_size=None):
    """Concatenate a tokenized batch and re-split it into fixed-size blocks.

    Args:
        examples: dict mapping column name -> list of token-id lists
            (one inner list per example), as produced by a batched
            ``Dataset.map`` over the tokenizer.
        chunk_size: length of each output block; defaults to the
            module-level ``block_size`` constant, preserving the original
            call signature used by ``Dataset.map``.

    Returns:
        dict with the same columns, each a list of ``chunk_size``-long
        lists, plus a ``labels`` column copied from ``input_ids``.
    """
    if chunk_size is None:
        chunk_size = block_size
    # Flatten each column's list-of-lists in one linear pass;
    # sum(lists, []) would be quadratic in the number of examples.
    concatenated_examples = {
        k: list(itertools.chain.from_iterable(examples[k]))
        for k in examples.keys()
    }
    total_length = len(next(iter(concatenated_examples.values())))
    # Drop the trailing remainder shorter than chunk_size (padding it
    # would be the alternative if that last fragment mattered).
    total_length = (total_length // chunk_size) * chunk_size
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # LM training uses the inputs themselves as targets.
    result["labels"] = result["input_ids"].copy()
    return result
|
|
|
|
|
|
# Length of each training chunk fed to the model (see group_texts).
block_size = 128


from datasets import load_dataset

# Renamed from `datasets` to avoid shadowing the `datasets` package
# that is imported at the top of the file.
raw_datasets = load_dataset('jed351/cantonese-wikipedia')


from transformers import AutoTokenizer

model_checkpoint = "Ayaka/bart-base-cantonese"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
# Tokenize every split in parallel; the raw "text" column is no longer
# needed once token ids exist, so drop it.
tokenized_datasets = raw_datasets.map(
    tokenize_function, batched=True, num_proc=4, remove_columns=["text"]
)
|
|
|
|
|
|
# Regroup the tokenized corpus into fixed-length blocks for LM training.
# Batched mapping lets group_texts concatenate 1000 examples at a time.
lm_datasets = tokenized_datasets.map(
    group_texts, batched=True, num_proc=4, batch_size=1000
)
|
|
|
|
|
|
from transformers import Trainer, TrainingArguments
from transformers import DataCollatorForLanguageModeling

# Collator applies dynamic masking: a fresh 15% of tokens is masked
# every time a batch is assembled.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm_probability=0.15
)
|
|
|
|
|
|
|
|
from transformers import AutoModelForMaskedLM

# Load the pretrained checkpoint with a masked-LM head for fine-tuning.
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
|
|
|
|
# Note: the output-dir string had a needless f-prefix (no placeholders);
# plain literal is equivalent and clearer.
training_args = TrainingArguments(
    "bart-finetuned-wikitext2",  # output directory for checkpoints
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    push_to_hub=False,
    per_device_train_batch_size=72,
    fp16=True,  # mixed precision; requires a CUDA-capable GPU
    save_steps=5000,
)
|
|
|
|
# Wire model, data, and collator together and launch fine-tuning.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["test"],
)

trainer.train()
|
|