# Train a character-level BERT masked language model ("pl-bert") on Hebrew
# phoneme data, from scratch, with the Hugging Face Trainer.

from datasets import load_dataset
from transformers import (
    BertTokenizerFast,
    BertConfig,
    BertForMaskedLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from evaluate import load
import torch
import numpy as np
|
|
# Load the raw corpus of "hebrew<TAB>phonemes" lines. The file names below
# are placeholders -- point data_files at your own train/eval splits.
raw = load_dataset(
    "text",
    data_files={"train": "train.tsv", "eval": "eval.tsv"},
)
ds_train = raw["train"]
ds_eval = raw["eval"]


# Split each "hebrew<TAB>phonemes" line into two columns.
def split_tab(examples):
    heb, phon = [], []
    for line in examples["text"]:
        h, p = line.split("\t", 1)  # split on the first tab only
        heb.append(h)
        phon.append(p)
    return {"hebrew": heb, "phonemes": phon}
|
|
ds_train = ds_train.map(split_tab, batched=True, remove_columns=["text"])
ds_eval = ds_eval.map(split_tab, batched=True, remove_columns=["text"])
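# Peek at one row to confirm the split (output is corpus-dependent).
print(ds_train[0])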
|
|
# Character inventory: punctuation, Latin letters, and the IPA symbols the
# phonemizer can emit.
punctuation = ';:,.!?¡¿—…"«»“” ’'
letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
letters_ipa = (
    "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘̩ᵻ"
)
specials = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
all_chars = list(punctuation) + list(letters) + list(letters_ipa)
# Deduplicate while preserving order.
vocab_chars = []
seen = set()
for c in all_chars:
    if c not in seen:
        seen.add(c)
        vocab_chars.append(c)
|
|
# Write the vocab file. [PAD] must come first so it gets id 0. WordPiece also
# needs "##"-prefixed continuation pieces; without them, any word longer than
# one character would tokenize to [UNK] instead of a character sequence.
with open("vocab.txt", "w", encoding="utf-8") as vf:
    for tok in specials:
        vf.write(tok + "\n")
    for c in vocab_chars:
        vf.write(c + "\n")
    for c in vocab_chars:
        vf.write("##" + c + "\n")
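# Sanity check: specials + plain characters + continuation pieces.
with open("vocab.txt", encoding="utf-8") as vf:
    print("vocab entries:", sum(1 for _ in vf))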
|
|
# A standard BERT tokenizer over the character vocabulary. Lowercasing and
# accent stripping are disabled so IPA symbols and diacritics survive intact.
tokenizer = BertTokenizerFast(
    vocab_file="vocab.txt",
    unk_token="[UNK]", pad_token="[PAD]",
    cls_token="[CLS]", sep_token="[SEP]",
    mask_token="[MASK]",
    do_lower_case=False,
    strip_accents=False,
    tokenize_chinese_chars=False,
)
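# Quick check that a multi-character phoneme word decomposes into characters
# rather than [UNK]. The input string is illustrative, not corpus data.
print(tokenizer.tokenize("ʃalˈom"))  # expected: ['ʃ', '##a', '##l', '##ˈ', '##o', '##m']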
|
|
# A BERT-base sized encoder (12 layers, 12 heads) with a narrower FFN.
config = BertConfig(
    vocab_size = len(tokenizer),
    hidden_size = 768,
    num_hidden_layers = 12,
    num_attention_heads = 12,
    intermediate_size = 2048,
    max_position_embeddings = 512,
    hidden_dropout_prob = 0.1,
    attention_probs_dropout_prob = 0.1,
)
|
|
# Tokenize the phoneme strings. [CLS]/[SEP] are added per line and later
# concatenated by the grouping step, so chunks can contain them mid-sequence
# (the usual tradeoff of run_mlm-style grouping).
def tokenize_fn(examples):
    return tokenizer(
        examples["phonemes"],
        return_attention_mask=True,
        add_special_tokens=True,
    )
|
|
tokenized_train = ds_train.map(
    tokenize_fn,
    batched=True,
    remove_columns=["hebrew", "phonemes"]
)
tokenized_eval = ds_eval.map(
    tokenize_fn,
    batched=True,
    remove_columns=["hebrew", "phonemes"]
)
|
|
# Concatenate all token ids and slice them into fixed-size blocks; the tail
# shorter than block_size is dropped.
block_size = 128
def group_texts(examples):
    all_ids = sum(examples["input_ids"], [])
    result = {"input_ids": [], "attention_mask": []}
    for i in range(0, len(all_ids) - block_size + 1, block_size):
        chunk = all_ids[i : i + block_size]
        result["input_ids"].append(chunk)
        result["attention_mask"].append([1] * block_size)
    return result
|
|
lm_train = tokenized_train.map(
    group_texts,
    batched=True,
    remove_columns=list(tokenized_train.column_names),
)
lm_eval = tokenized_eval.map(
    group_texts,
    batched=True,
    remove_columns=list(tokenized_eval.column_names),
)
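# Every grouped example should now be exactly block_size tokens long.
assert all(len(ids) == block_size for ids in lm_train[:8]["input_ids"])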
|
|
# Standard MLM collator: selects 15% of tokens, writes their original ids
# into `labels`, and sets every other label to -100 so the loss ignores it.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
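# Peek at one collated batch: roughly 15% of positions should carry a label
# other than -100 (i.e. were selected for masking).
batch = data_collator([lm_train[i] for i in range(4)])
print(batch["input_ids"].shape, (batch["labels"] != -100).float().mean().item())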
|
|
# Masked-token accuracy and perplexity, computed only over the positions the
# collator actually masked (labels != -100).
accuracy_metric = load("accuracy")
def compute_metrics(eval_pred):
    logits, labels = eval_pred.predictions, eval_pred.label_ids
    logits = logits.reshape(-1, logits.shape[-1])
    labels = labels.reshape(-1)
    mask = labels != -100

    preds = np.argmax(logits, axis=-1)
    acc = accuracy_metric.compute(
        predictions=preds[mask], references=labels[mask]
    )["accuracy"]

    # Numerically stable cross-entropy via the log-sum-exp trick;
    # perplexity = exp(mean cross-entropy over masked positions).
    max_logits = np.max(logits[mask], axis=-1, keepdims=True)
    stable = logits[mask] - max_logits
    logsumexp = max_logits.flatten() + np.log(np.exp(stable).sum(axis=-1))
    true_logits = logits[mask, labels[mask]]
    xent = -np.mean(true_logits - logsumexp)
    ppl = float(np.exp(xent))

    return {"accuracy": acc, "perplexity": ppl}
|
|
# Train from scratch: random initialization, no pretrained checkpoint.
model = BertForMaskedLM(config)
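# Model size, for reference.
n_params = sum(p.numel() for p in model.parameters())
print(f"parameters: {n_params / 1e6:.1f}M")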
|
|
training_args = TrainingArguments(
    output_dir = "pl-bert",
    overwrite_output_dir = True,
    num_train_epochs = 20,
    per_device_train_batch_size = 196,
    per_device_eval_batch_size = 196,
    warmup_steps = 400,
    learning_rate = 1e-5,
    weight_decay = 0.001,
    eval_strategy = "epoch",
    save_strategy = "epoch",
    load_best_model_at_end = True,
    metric_for_best_model = "perplexity",  # Trainer resolves this to eval_perplexity
    greater_is_better = False,
    logging_strategy = "steps",
    logging_steps = 25,
    save_total_limit = 3,
    push_to_hub = False,
    eval_accumulation_steps = 1,  # move eval logits to CPU every step to cap GPU memory
)
|
|
trainer = Trainer(
    model = model,
    args = training_args,
    data_collator = data_collator,
    train_dataset = lm_train,
    eval_dataset = lm_eval,
    compute_metrics = compute_metrics,
)
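# Optional: evaluate the untrained model once to get a baseline for the
# accuracy/perplexity numbers logged during training.
print(trainer.evaluate())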
|
|
# Train; with load_best_model_at_end=True the lowest-perplexity checkpoint
# is restored automatically when training finishes.
trainer.train()
|
|
best_ckpt = trainer.state.best_model_checkpoint
print(f"Best checkpoint directory: {best_ckpt}")
|
|
# Export the raw state dict of the best checkpoint.
best_model = BertForMaskedLM.from_pretrained(best_ckpt, config=config)
torch.save(best_model.state_dict(), "pl-bert-best.pt")
print("[✔] Saved best model weights to pl-bert-best.pt")
|
|
# Save in Hugging Face format so everything reloads with from_pretrained.
best_model.save_pretrained("pl-bert-final")
tokenizer.save_pretrained("pl-bert-final")
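# Minimal usage sketch (assumes the "pl-bert-final" directory saved above):
# mask one character of an illustrative phoneme string and inspect the top
# predictions. The input string is an example, not corpus data.
from transformers import pipeline
fill = pipeline("fill-mask", model="pl-bert-final", tokenizer="pl-bert-final")
print(fill("ʃal[MASK]m")[:3])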
|
|