# aramt5/src/train_tokeniser.py
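"""
Train the shared aramt5 SentencePiece tokeniser.

Concatenates both dialect corpora into a single plain-text file, prepending the
dialect task prefixes used by train_t5.py, trains a unigram SentencePiece model
on it, and re-saves the result in HuggingFace T5Tokenizer format under
src/tokeniser/.
"""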
import json
import os
import sentencepiece as spm
from transformers import T5Tokenizer
# Corpus files for both dialects
CORPUS_FILES = [
    "src/data/syriac_west_clean_corpus.jsonl",
    "src/data/syriac_east_clean_corpus.jsonl",
]
# Dialect task prefixes (must match train_t5.py)
DIALECT_PREFIXES = {
    "west": "Syriac2WestLatin: ",
    "east": "Syriac2EastLatin: ",
}
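# Each corpus line is a JSON record whose "transliteration" object carries the
# fields read below: "src", "tgt" and an optional "dialect". Illustrative shape
# only (the values here are placeholders, not real corpus data):
#
#   {"transliteration": {"dialect": "west", "src": "<Syriac text>", "tgt": "<Latin transliteration>"}}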
# Load corpus data from both dialects
print("Loading corpus data...")
with open("src/data/tokeniser_corpus.txt", "w", encoding="utf-8") as f_out:
    total_records = 0
    for corpus_file in CORPUS_FILES:
        if not os.path.exists(corpus_file):
            print(f"Warning: {corpus_file} not found, skipping...")
            continue
        print(f"Processing {corpus_file}...")
        with open(corpus_file, "r", encoding="utf-8") as f_in:
            for line in f_in:
                item = json.loads(line)
                dialect = item["transliteration"].get("dialect", "west")
                prefix = DIALECT_PREFIXES.get(dialect, DIALECT_PREFIXES["west"])
                src = item["transliteration"]["src"]
                tgt = item["transliteration"]["tgt"]
                # Write with dialect prefix to train tokeniser on task format
                f_out.write(f"{prefix}{src}\n")
                f_out.write(tgt + "\n")
                total_records += 1
print(f"Total records loaded: {total_records}")
# Ensure output directory exists
os.makedirs("src/tokeniser", exist_ok=True)
# Train the sentence piece model
print("Training SentencePiece model...")
spm.SentencePieceTrainer.Train(
    input="src/data/tokeniser_corpus.txt",
    model_prefix="src/tokeniser/aramt5_sp",
    vocab_size=16000,              # adjust as needed
    model_type="unigram",          # worth testing with "bpe"
    character_coverage=1.0,        # to preserve rare characters
    max_sentence_length=8384 * 2,  # to handle long sequences
    pad_id=0,
    unk_id=1,
    bos_id=2,
    eos_id=3,
    user_defined_symbols=[
        "<pad>",
        "<s>",
        "</s>",
        "Syriac2WestLatin:",  # dialect task prefix (West/Serto)
        "Syriac2EastLatin:",  # dialect task prefix (East/Madnḥaya)
    ],
)
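# SentencePiece writes aramt5_sp.model and aramt5_sp.vocab under the
# model_prefix path above; the .model file is what T5Tokenizer loads below.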
# Convert to a HF-compatible format
print("Converting to HuggingFace format...")
tokeniser = T5Tokenizer(vocab_file="src/tokeniser/aramt5_sp.model", legacy=False)
tokeniser.save_pretrained("src/tokeniser/")
print("Tokeniser saved to src/tokeniser/")
print(
    "Next step: Run augment_atomic_tokens.py to boost single-token training examples."
)