#!/usr/bin/env python3
# pl-bert_training.py
from datasets import load_dataset
from transformers import (
    BertTokenizerFast,
    BertConfig,
    BertForMaskedLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from evaluate import load # use the 'evaluate' library for metrics
import torch
import yaml
import numpy as np
# 1. Load dataset and split
full_ds = load_dataset("thewh1teagle/phonikud-phonemes-data", split="train[:5000000]")
ds_train = full_ds.select(range(0, 4700000))        # first 4.7M examples
ds_eval = full_ds.select(range(4700000, 5000000))   # last 300k examples
# 2. Split "text" column into Hebrew and phonemes
def split_tab(examples):
    heb, phon = [], []
    for line in examples["text"]:
        h, p = line.split("\t")
        heb.append(h)
        phon.append(p)
    return {"hebrew": heb, "phonemes": phon}
ds_train = ds_train.map(split_tab, batched=True, remove_columns=["text"])
ds_eval = ds_eval.map(split_tab, batched=True, remove_columns=["text"])
# 3. Build character‐level phoneme vocab
punctuation = ';:,.!?¡¿—…"«»“” ’'
letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
letters_ipa = (
    "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘̩ᵻ"
)
specials = ["[PAD]","[UNK]","[CLS]","[SEP]","[MASK]"]
all_chars = list(punctuation) + list(letters) + list(letters_ipa)
vocab_chars = []
seen = set()
for c in all_chars:
    if c not in seen:
        seen.add(c)
        vocab_chars.append(c)
with open("vocab.txt", "w", encoding="utf-8") as vf:
    for tok in specials:
        vf.write(tok + "\n")
    for c in vocab_chars:
        vf.write(c + "\n")
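# Optional sanity check (illustrative, not required for training): the written
# vocab should contain the 5 special tokens plus one line per unique character.
print(f"Wrote vocab.txt with {len(specials) + len(vocab_chars)} tokens")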
# 4. Initialize tokenizer & model config
tokenizer = BertTokenizerFast(
    vocab_file="vocab.txt",
    unk_token="[UNK]", pad_token="[PAD]",
    cls_token="[CLS]", sep_token="[SEP]",
    mask_token="[MASK]",
    do_lower_case=False,
    strip_accents=False,
    tokenize_chinese_chars=False,
)
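# do_lower_case=False and strip_accents=False keep IPA symbols and combining
# diacritics intact; tokenize_chinese_chars=False disables the per-character
# splitting heuristic intended for CJK text.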
config = BertConfig(
    vocab_size                   = len(tokenizer),
    hidden_size                  = 768,
    num_hidden_layers            = 12,
    num_attention_heads          = 12,
    intermediate_size            = 2048,
    max_position_embeddings      = 512,
    hidden_dropout_prob          = 0.1,
    attention_probs_dropout_prob = 0.1,
)
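# Standard BERT-base geometry (12 layers, 12 heads, hidden size 768), except
# intermediate_size=2048 instead of the usual 3072, giving a slightly smaller FFN.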
# 5. Tokenize only the phoneme sequences
def tokenize_fn(examples):
    return tokenizer(
        examples["phonemes"],
        return_attention_mask=True,
        add_special_tokens=True,
    )
tokenized_train = ds_train.map(
    tokenize_fn,
    batched=True,
    remove_columns=["hebrew", "phonemes"],
)
tokenized_eval = ds_eval.map(
    tokenize_fn,
    batched=True,
    remove_columns=["hebrew", "phonemes"],
)
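# Each example now carries the tokenizer outputs (input_ids, attention_mask,
# token_type_ids); the Hebrew text is dropped, so MLM sees phonemes only.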
# 6. Group into fixed-length blocks for MLM
block_size = 128
def group_texts(examples):
    all_ids = sum(examples["input_ids"], [])
    result = {"input_ids": [], "attention_mask": []}
    for i in range(0, len(all_ids) - block_size + 1, block_size):
        chunk = all_ids[i : i + block_size]
        result["input_ids"].append(chunk)
        result["attention_mask"].append([1] * block_size)
    return result
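# Standard "concatenate and chunk" preprocessing: sentences are concatenated
# (their [CLS]/[SEP] tokens stay in the stream) and any trailing remainder
# shorter than block_size is discarded.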
lm_train = tokenized_train.map(
    group_texts,
    batched=True,
    remove_columns=list(tokenized_train.column_names),
)
lm_eval = tokenized_eval.map(
    group_texts,
    batched=True,
    remove_columns=list(tokenized_eval.column_names),
)
# 7. Data collator for MLM
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
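# With mlm=True the collator selects 15% of tokens as prediction targets:
# 80% of those become [MASK], 10% a random token, 10% stay unchanged, and all
# other positions get label -100 so they are ignored by the loss.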
# 8. Metrics for accuracy + perplexity
accuracy_metric = load("accuracy")
def compute_metrics(eval_pred):
    logits, labels = eval_pred.predictions, eval_pred.label_ids
    logits = logits.reshape(-1, logits.shape[-1])
    labels = labels.reshape(-1)
    mask = labels != -100
    preds = np.argmax(logits, axis=-1)
    acc = accuracy_metric.compute(
        predictions=preds[mask], references=labels[mask]
    )["accuracy"]
    max_logits = np.max(logits[mask], axis=-1, keepdims=True)
    stable = logits[mask] - max_logits
    logsumexp = max_logits.flatten() + np.log(np.exp(stable).sum(axis=-1))
    true_logits = logits[mask, labels[mask]]
    xent = -np.mean(true_logits - logsumexp)
    ppl = float(np.exp(xent))
    return {"accuracy": acc, "perplexity": ppl}
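# Accuracy and perplexity are computed only over masked positions
# (labels != -100); the log-sum-exp above keeps the cross-entropy numerically
# stable when exponentiating logits.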
# 9. Initialize model & Trainer
model = BertForMaskedLM(config)
# 9b) Load your best .pt checkpoint into it
# ckpt_path = "/dev/hdd/Users/Oron/tts/pl-bert/pl-bert-best1.pt"
# state_dict = torch.load(ckpt_path, map_location="cpu")
# model.load_state_dict(state_dict, strict=False)
# print(f"[✔] Loaded pretrained PL‑BERT weights from {ckpt_path}")
training_args = TrainingArguments(
    output_dir                  = "pl-bert",
    overwrite_output_dir        = True,
    num_train_epochs            = 20,
    per_device_train_batch_size = 196,
    per_device_eval_batch_size  = 196,
    warmup_steps                = 400,
    learning_rate               = 1e-5,
    weight_decay                = 0.001,
    eval_strategy               = "epoch",
    save_strategy               = "epoch",
    load_best_model_at_end      = True,
    metric_for_best_model       = "perplexity",
    greater_is_better           = False,
    logging_strategy            = "steps",
    logging_steps               = 25,
    save_total_limit            = 3,
    push_to_hub                 = False,
    eval_accumulation_steps     = 1,
    # fp16=True,  # uncomment if you want mixed precision
)
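# The Trainer logs metrics with an "eval_" prefix, so metric_for_best_model
# "perplexity" resolves to "eval_perplexity"; greater_is_better=False keeps the
# checkpoint with the lowest perplexity.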
trainer = Trainer(
    model           = model,
    args            = training_args,
    data_collator   = data_collator,
    train_dataset   = lm_train,
    eval_dataset    = lm_eval,
    compute_metrics = compute_metrics,
)
# 10. Train & save best checkpoint and .pt file
trainer.train()
best_ckpt = trainer.state.best_model_checkpoint
print(f"Best checkpoint directory: {best_ckpt}")
# Load the best checkpoint and save a raw .pt state_dict
best_model = BertForMaskedLM.from_pretrained(best_ckpt, config=config)
torch.save(best_model.state_dict(), "pl-bert-best.pt")
print("[✔] Saved best model weights to pl-bert-best.pt")
# Also keep HF format
best_model.save_pretrained("pl-bert-final")
tokenizer.save_pretrained("pl-bert-final")
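# Optional reload check (illustrative): the exported state_dict should load
# back into a fresh model built from the same config with no missing keys.
reloaded = BertForMaskedLM(config)
reloaded.load_state_dict(torch.load("pl-bert-best.pt", map_location="cpu"))
print("[✔] pl-bert-best.pt reloads cleanly into BertForMaskedLM(config)")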