"""
AramT5 Curriculum Learning Trainer
Features:
- Curriculum learning: short → long sequences
- Catastrophic forgetting mitigation: mixes short examples into later stages
- Character Error Rate (CER) evaluation for transliteration quality
- Early stopping based on validation loss improvement threshold
"""
import argparse
import subprocess
import sys
from pathlib import Path
import numpy as np
import torch
from datasets import concatenate_datasets, load_dataset
from transformers import (DataCollatorForSeq2Seq, EarlyStoppingCallback,
Seq2SeqTrainer, Seq2SeqTrainingArguments, T5Config,
T5ForConditionalGeneration, T5TokenizerFast)
# =============================================================================
# Configuration
# =============================================================================
# All paths are anchored at the project root, taken to be the parent of the
# directory containing this file (i.e. the parent of src/).
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
_DATA_DIR = _PROJECT_ROOT / "src" / "data"

# Default training corpora: the balanced corpus
# (40% single-word, 30% two-word, 30% multi-word).
# The augmented corpus was 98.5% single-word, which caused truncated
# multi-word outputs, so it is only used as the input to balancing.
DEFAULT_WEST_DATA = str(_DATA_DIR / "syriac_west_balanced_corpus.jsonl")
DEFAULT_EAST_DATA = str(_DATA_DIR / "syriac_east_balanced_corpus.jsonl")

# Inputs to balance_corpus.py
AUGMENTED_WEST_DATA = _DATA_DIR / "syriac_west_augmented_corpus.jsonl"
AUGMENTED_EAST_DATA = _DATA_DIR / "syriac_east_augmented_corpus.jsonl"

# Inputs to augment_atomic_tokens.py
CLEAN_WEST_DATA = _DATA_DIR / "syriac_west_clean_corpus.jsonl"
CLEAN_EAST_DATA = _DATA_DIR / "syriac_east_clean_corpus.jsonl"

# Tokeniser directory and checkpoint output directory (both as plain strings,
# since they are used as argparse defaults).
DEFAULT_TOKENISER = str(_PROJECT_ROOT / "src" / "tokeniser")
DEFAULT_OUTPUT_DIR = str(_PROJECT_ROOT / "checkpoints")
def ensure_augmented_corpus():
    """
    Make sure the augmented corpus files are present and fresh.

    The augmented corpus is considered stale when either augmented file is
    missing, or when an existing clean corpus file has a newer modification
    time than its augmented counterpart. A stale corpus is rebuilt by running
    src/data/augment_atomic_tokens.py with the current interpreter, using the
    project root as working directory.

    Raises:
        FileNotFoundError: the augmentation script does not exist.
        RuntimeError: the augmentation script exited with a non-zero status.
    """
    if not (AUGMENTED_WEST_DATA.exists() and AUGMENTED_EAST_DATA.exists()):
        print("Augmented corpus not found, will generate...")
        stale = True
    else:
        stale = False
        # Compare each clean source against the file derived from it.
        source_pairs = (
            (CLEAN_WEST_DATA, AUGMENTED_WEST_DATA),
            (CLEAN_EAST_DATA, AUGMENTED_EAST_DATA),
        )
        for clean_path, augmented_path in source_pairs:
            if (
                clean_path.exists()
                and clean_path.stat().st_mtime > augmented_path.stat().st_mtime
            ):
                print("Clean corpus is newer than augmented corpus, regenerating...")
                stale = True

    if not stale:
        print("Augmented corpus is up-to-date.")
        return

    script_path = _PROJECT_ROOT / "src/data/augment_atomic_tokens.py"
    if not script_path.exists():
        raise FileNotFoundError(
            f"Cannot regenerate augmented corpus: {script_path} not found"
        )

    print("Running augment_atomic_tokens.py to generate augmented training data...")
    completed = subprocess.run(
        [sys.executable, str(script_path)],
        cwd=str(_PROJECT_ROOT),
        capture_output=True,
        text=True,
    )
    if completed.returncode != 0:
        print(f"Error running augment_atomic_tokens.py:\n{completed.stderr}")
        raise RuntimeError("Failed to generate augmented corpus")

    # Output was captured, so replay the script's stdout once it has succeeded.
    print(completed.stdout)
    print("Augmented corpus generated successfully.")
def ensure_balanced_corpus():
    """
    Make sure the balanced corpus files are present and fresh.

    Pipeline: clean_corpus -> augmented_corpus -> balanced_corpus

    The upstream augmented corpus is checked (and rebuilt if needed) first.
    The balanced corpus is then considered stale when either balanced file is
    missing, or when an existing augmented file has a newer modification time
    than its balanced counterpart. A stale corpus is rebuilt by running
    src/data/balance_corpus.py with the current interpreter, using the project
    root as working directory.

    Raises:
        FileNotFoundError: the balancing script does not exist.
        RuntimeError: the balancing script exited with a non-zero status.
    """
    # Upstream dependency first.
    ensure_augmented_corpus()

    west_balanced = Path(DEFAULT_WEST_DATA)
    east_balanced = Path(DEFAULT_EAST_DATA)

    if not (west_balanced.exists() and east_balanced.exists()):
        print("Balanced corpus not found, will generate...")
        stale = True
    else:
        stale = False
        # Compare each augmented source against the file derived from it.
        source_pairs = (
            (AUGMENTED_WEST_DATA, west_balanced),
            (AUGMENTED_EAST_DATA, east_balanced),
        )
        for augmented_path, balanced_path in source_pairs:
            if (
                augmented_path.exists()
                and augmented_path.stat().st_mtime > balanced_path.stat().st_mtime
            ):
                print("Augmented corpus is newer than balanced corpus, regenerating...")
                stale = True

    if not stale:
        print("Balanced corpus is up-to-date.")
        return

    script_path = _PROJECT_ROOT / "src/data/balance_corpus.py"
    if not script_path.exists():
        raise FileNotFoundError(
            f"Cannot regenerate balanced corpus: {script_path} not found"
        )

    print("Running balance_corpus.py to generate balanced training data...")
    completed = subprocess.run(
        [sys.executable, str(script_path)],
        cwd=str(_PROJECT_ROOT),
        capture_output=True,
        text=True,
    )
    if completed.returncode != 0:
        print(f"Error running balance_corpus.py:\n{completed.stderr}")
        raise RuntimeError("Failed to generate balanced corpus")

    # Output was captured, so replay the script's stdout once it has succeeded.
    print(completed.stdout)
    print("Balanced corpus generated successfully.")
# Curriculum learning stage configurations, keyed by stage number (1-6).
#
# Keys consumed in this file:
#   description      - printed in the stage banner by load_and_prepare_data()
#   num_samples      - target size of the sampled pool (before train/val split)
#   max_src_length   - upper bound on source length, in characters
#   short_mix_ratio  - fraction of the pool drawn from sources <= short_threshold
#   short_threshold  - length cut-off for "short" examples (defaults to 15)
#   new_range_ratio  - fraction of the pool drawn from sources >= new_range_min
#   new_range_min    - lower bound of the length range newly added by the stage
#   num_epochs       - passed to Seq2SeqTrainingArguments.num_train_epochs
#   learning_rate    - passed to Seq2SeqTrainingArguments.learning_rate
#
# NOTE(review): "repetition_penalty" (stages 5 and 6) is not read by any code
# in this file; train() explicitly comments that repetition_penalty is not
# used. load_and_prepare_data() also looks up the optional keys
# "middle_oversample" and "min_src_length", which no stage below sets.
STAGE_CONFIGS = {
    1: {
        "description": "Baseline: short sequences only",
        "num_samples": 20_000,
        "max_src_length": 15,  # Characters in source (short words)
        "short_mix_ratio": 0.0,  # No mixing needed in stage 1
        "num_epochs": 30,
        "learning_rate": 3e-4,
    },
    2: {
        "description": "Expansion: short phrases",
        "num_samples": 40_000,
        "max_src_length": 30,
        "short_mix_ratio": 0.12,  # 12% short examples from previous stages
        "short_threshold": 15,  # ≤15 chars (Stage 1)
        "new_range_ratio": 0.50,  # 50% from new range (16-30 chars)
        "new_range_min": 16,
        "num_epochs": 20,
        "learning_rate": 1.2e-4,
    },
    3: {
        "description": "Expansion: medium phrases",
        "num_samples": 60_000,
        "max_src_length": 50,
        "short_mix_ratio": 0.12,  # 12% short examples from previous stages
        "short_threshold": 30,  # ≤30 chars (Stage 1+2)
        "new_range_ratio": 0.50,  # 50% from new range (31-50 chars)
        "new_range_min": 31,
        "num_epochs": 20,
        "learning_rate": 1e-4,
    },
    4: {
        "description": "Extension: longer phrases",
        "num_samples": 120_000,  # Increased to better learn multi-word patterns
        "max_src_length": 70,
        "short_mix_ratio": 0.18,  # 18% short examples from previous stages (boosted for retention)
        "short_threshold": 50,  # ≤50 chars (Stage 1+2+3)
        "new_range_ratio": 0.45,  # 45% from new range (51-70 chars)
        "new_range_min": 51,
        "num_epochs": 10,
        "learning_rate": 8e-5,  # Higher LR to unlearn early-stopping bias from imbalanced data
    },
    5: {
        "description": "Extension: longer sentences",
        "num_samples": 150_000,  # Increased to better learn multi-word patterns
        "max_src_length": 100,
        "short_mix_ratio": 0.18,  # 18% short examples from previous stages (boosted for retention)
        "short_threshold": 70,  # ≤70 chars (Stage 1+2+3+4)
        "new_range_ratio": 0.45,  # 45% from new range (71-100 chars)
        "new_range_min": 71,
        "num_epochs": 10,
        "learning_rate": 5e-5,  # Slightly higher to reinforce multi-word patterns
        "repetition_penalty": 1.2,  # not read anywhere in this file (see note above)
    },
    6: {
        "description": "Full practical corpus: sentences and short paragraphs",
        "num_samples": 180_000,  # Increased to better learn multi-word patterns
        "max_src_length": 150,
        "short_mix_ratio": 0.20,  # 20% short examples from previous stages (highest retention)
        "short_threshold": 100,  # ≤100 chars (Stage 1+2+3+4+5)
        "new_range_ratio": 0.40,  # 40% from new range (101-150 chars)
        "new_range_min": 101,
        "num_epochs": 10,
        "learning_rate": 4e-5,  # Fine-tuning polish
        "repetition_penalty": 1.2,  # not read anywhere in this file (see note above)
    },
}

# Early stopping config, passed to EarlyStoppingCallback in train().
# Patience is counted in evaluations (one evaluation per epoch in train()).
EARLY_STOPPING_PATIENCE = 3
# Minimum improvement of the monitored metric (eval_loss) that counts as progress.
# NOTE(review): originally described as a "0.5% improvement threshold", but the
# value is handed to EarlyStoppingCallback unscaled - confirm whether the
# callback treats it as an absolute or a relative delta.
EARLY_STOPPING_THRESHOLD = 0.005
def parse_args():
    """
    Parse the trainer's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with attributes stage, hf_model, west_data,
        east_data, tokeniser, output_dir, batch_size, no_early_stopping
        and resume.
    """
    cli = argparse.ArgumentParser(description="AramT5 Curriculum Learning Trainer")

    # (flag, add_argument keyword arguments), in the order they appear in --help.
    option_table = (
        (
            "--stage",
            {
                "type": int,
                "default": 1,
                "choices": [1, 2, 3, 4, 5, 6],
                "help": "Training stage (1=baseline, 2=medium-long, 3=expansion, 4=extension, 5=longer sentences, 6=full practical)",
            },
        ),
        (
            "--hf-model",
            {
                "type": str,
                "default": None,
                "help": "HuggingFace model ID to fine-tune (required for stage 2+)",
            },
        ),
        (
            "--west-data",
            {
                "type": str,
                "default": DEFAULT_WEST_DATA,
                "help": "Path to West Syriac corpus",
            },
        ),
        (
            "--east-data",
            {
                "type": str,
                "default": DEFAULT_EAST_DATA,
                "help": "Path to East Syriac corpus",
            },
        ),
        (
            "--tokeniser",
            {
                "type": str,
                "default": DEFAULT_TOKENISER,
                "help": "Path to tokeniser",
            },
        ),
        (
            "--output-dir",
            {
                "type": str,
                "default": DEFAULT_OUTPUT_DIR,
                "help": "Output directory for checkpoints",
            },
        ),
        (
            "--batch-size",
            {
                "type": int,
                "default": 16,
                "help": "Per-device batch size",
            },
        ),
        (
            "--no-early-stopping",
            {
                "action": "store_true",
                "help": "Disable early stopping",
            },
        ),
        (
            # Bare "--resume" yields "auto"; "--resume PATH" yields PATH;
            # omitting the flag yields None.
            "--resume",
            {
                "type": str,
                "nargs": "?",
                "const": "auto",
                "default": None,
                "help": "Resume from checkpoint. Use --resume for auto-detect or --resume path/to/checkpoint",
            },
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()
# =============================================================================
# Model Loading
# =============================================================================
def load_model_and_tokeniser(
    stage: int = 1,
    hf_model: str | None = None,
    tokeniser_path: str = DEFAULT_TOKENISER,
):
    """
    Build the model and tokeniser for a training stage.

    Stage 1 creates a randomly initialised T5 model sized to the tokeniser's
    vocabulary; every other stage loads pretrained weights from ``hf_model``.

    Args:
        stage: Training stage (1 = train from scratch, otherwise fine-tune)
        hf_model: Model ID/path given to ``from_pretrained`` (needed unless stage is 1)
        tokeniser_path: Path to local tokeniser directory

    Returns:
        Tuple of (model, tokeniser)

    Raises:
        ValueError: ``stage`` is not 1 and ``hf_model`` is empty or None.
    """
    tokeniser = T5TokenizerFast.from_pretrained(tokeniser_path)
    vocab_size = tokeniser.vocab_size
    pad_token_id = tokeniser.pad_token_id

    if stage != 1:
        # Later stages continue from a previously trained model.
        if not hf_model:
            raise ValueError(f"Stage {stage} requires --hf-model argument")
        print(f"Stage {stage}: Loading model from HuggingFace: {hf_model}")
        return T5ForConditionalGeneration.from_pretrained(hf_model), tokeniser

    print("Stage 1: Initialising new model from scratch...")
    # NOTE(review): the embedding size comes from tokeniser.vocab_size;
    # confirm the tokeniser has no added tokens outside that count.
    fresh_config = T5Config(
        vocab_size=vocab_size,
        d_model=512,
        d_ff=2048,
        num_layers=6,
        num_heads=8,
        pad_token_id=pad_token_id,
        # Decoding starts from the pad token.
        decoder_start_token_id=pad_token_id,
        tie_word_embeddings=True,
    )
    return T5ForConditionalGeneration(fresh_config), tokeniser
# =============================================================================
# Data Processing
# =============================================================================
def get_src_length(example):
    """Return ``len()`` of the example's source text (``transliteration.src``)."""
    source_text = example["transliteration"]["src"]
    return len(source_text)
def create_tokenise_function(tokeniser):
    """Return a batched tokenisation function bound to ``tokeniser``."""
    pad_id = tokeniser.pad_token_id
    west_prefix = "Syriac2WestLatin: "
    # Any dialect other than "east" (including a missing one) uses the West prefix.
    prefix_by_dialect = {"east": "Syriac2EastLatin: "}

    def tokenise_function(example: dict) -> dict:
        """
        Tokenise a batch, prepending a dialect-specific task prefix to each source.

        Task prefixes:
        - "Syriac2WestLatin: " for West Syriac (Serto)
        - "Syriac2EastLatin: " for East Syriac (Madnḥaya)

        Sources and targets are truncated/padded to 256 tokens. The result is
        the tokeniser output for the sources plus a "labels" entry holding the
        target token ids, with padding positions set to -100.
        """
        records = example["transliteration"]
        sources = [
            f"{prefix_by_dialect.get(record.get('dialect', 'west'), west_prefix)}{record['src']}"
            for record in records
        ]
        targets = [record["tgt"] for record in records]

        encoded = tokeniser(
            sources, max_length=256, truncation=True, padding="max_length"
        )
        target_ids = tokeniser(
            targets, max_length=256, truncation=True, padding="max_length"
        )["input_ids"]

        # -100 marks positions the loss must ignore.
        encoded["labels"] = [
            [-100 if token_id == pad_id else token_id for token_id in row]
            for row in target_ids
        ]
        return encoded

    return tokenise_function
def _shuffled_head(pool, count, seed):
    """Shuffle ``pool`` with ``seed`` and keep at most ``count`` rows."""
    return pool.shuffle(seed=seed).select(range(min(count, len(pool))))


def _oversample_to(pool, count, seed=42):
    """
    Draw ``count`` shuffled rows from ``pool``, repeating the pool if it is too small.

    Returns:
        Tuple of (dataset, repeats). ``repeats`` is 1 when the pool was large
        enough, the number of concatenated copies (>= 2) when the pool had to
        be repeated, and 0 when rows were requested from an empty pool (the
        empty pool is returned unchanged rather than dividing by zero).
    """
    available = len(pool)
    if available >= count:
        return pool.shuffle(seed=seed).select(range(count)), 1
    if available == 0:
        return pool, 0
    repeats = count // available + 1
    repeated = concatenate_datasets([pool] * repeats)
    return repeated.shuffle(seed=seed).select(range(count)), repeats


def _sample_for_stage(full_dataset, filtered_dataset, stage_config, stage, num_samples):
    """Build the (pre-split) sample pool using the mixing strategy the stage config selects."""
    short_mix_ratio = stage_config["short_mix_ratio"]
    max_len = stage_config["max_src_length"]

    if stage_config.get("middle_oversample", False):
        # Special handling: oversample the rare 15-100 char range to build
        # bridge competence before the full corpus.
        num_short = int(num_samples * short_mix_ratio)
        num_middle = int(num_samples * 0.40)  # 40% from middle range (15-100)
        num_main = num_samples - num_short - num_middle

        # Short examples (≤15 chars = Stage 1 range) for forgetting mitigation
        short_threshold = 15
        short_examples = _shuffled_head(
            full_dataset.filter(lambda x: x["src_length"] <= short_threshold),
            num_short,
            seed=42,
        )
        print(f" Short examples (≤{short_threshold} chars): {len(short_examples)}")

        middle_pool = filtered_dataset.filter(lambda x: 15 < x["src_length"] <= 100)
        middle_examples, repeats = _oversample_to(middle_pool, num_middle)
        if repeats == 0:
            # Previously this case raised ZeroDivisionError.
            print(" WARNING: No examples in middle range!")
        elif repeats > 1:
            print(
                f" Middle-range examples (15-100 chars, oversampled): {len(middle_examples)}"
            )
        else:
            print(f" Middle-range examples (15-100 chars): {len(middle_examples)}")

        # Main examples from full filtered set
        main_examples = _shuffled_head(filtered_dataset, num_main, seed=43)
        print(f" Main examples: {len(main_examples)}")

        parts = [short_examples, middle_examples, main_examples]

    elif short_mix_ratio > 0 and stage > 1:
        new_range_ratio = stage_config.get("new_range_ratio", 0)
        new_range_min = stage_config.get("new_range_min", 0)
        num_short = int(num_samples * short_mix_ratio)

        # Short examples = everything from previous stages (forgetting mitigation)
        short_threshold = stage_config.get("short_threshold", 15)
        short_examples = _shuffled_head(
            full_dataset.filter(lambda x: x["src_length"] <= short_threshold),
            num_short,
            seed=42,
        )
        print(
            f" Short examples (≤{short_threshold} chars, previous stages): {len(short_examples)}"
        )

        if new_range_ratio > 0 and new_range_min > 0:
            # Stratified: short + new_range + remainder
            num_new_range = int(num_samples * new_range_ratio)
            num_remainder = num_samples - num_short - num_new_range

            # New range examples - these are what the model needs to learn
            new_range_pool = filtered_dataset.filter(
                lambda x: x["src_length"] >= new_range_min
            )
            print(
                f" New range pool ({new_range_min}-{max_len} chars): {len(new_range_pool)} available"
            )
            # Oversample if needed (these are scarce!)
            new_range_examples, repeats = _oversample_to(new_range_pool, num_new_range)
            if repeats == 0:
                print(" WARNING: No examples in new range!")
            elif repeats > 1:
                print(
                    f" New range examples (oversampled {repeats}x): {len(new_range_examples)}"
                )
            else:
                print(f" New range examples: {len(new_range_examples)}")

            # Remainder from full filtered set (includes all lengths up to max)
            remainder_examples = _shuffled_head(filtered_dataset, num_remainder, seed=43)
            print(f" Remainder examples: {len(remainder_examples)}")

            parts = [short_examples, new_range_examples, remainder_examples]
        else:
            # Original logic: just short + main
            num_main = num_samples - num_short

            # Optional minimum length for the main pool in later stages
            min_len = stage_config.get("min_src_length", 0)
            if min_len > 0:
                main_pool = filtered_dataset.filter(
                    lambda x: x["src_length"] >= min_len
                )
                print(
                    f" Main pool after min_length={min_len} filter: {len(main_pool)} examples"
                )
            else:
                main_pool = filtered_dataset

            main_examples = _shuffled_head(main_pool, num_main, seed=42)
            print(f" Main examples: {len(main_examples)}")

            parts = [short_examples, main_examples]

    else:
        # No mixing: plain random sample of the filtered pool.
        return filtered_dataset.shuffle(seed=42).select(range(num_samples))

    # Combine and shuffle
    return concatenate_datasets(parts).shuffle(seed=42)


def load_and_prepare_data(
    stage_config: dict,
    stage: int = 1,
    west_data: str = DEFAULT_WEST_DATA,
    east_data: str = DEFAULT_EAST_DATA,
):
    """
    Load and prepare data according to curriculum learning stage.

    Both corpora are loaded and concatenated, annotated with a "src_length"
    column, sorted by it, filtered to ``max_src_length`` and sampled according
    to the stage's mixing settings. The sample is then split into train and
    validation sets (10% validation for stage >= 4, 20% otherwise).

    Args:
        stage_config: Configuration dict for the current stage
        stage: Training stage number (for logging and mixing logic)
        west_data: Path to West Syriac corpus JSONL file
        east_data: Path to East Syriac corpus JSONL file

    Returns:
        Tuple of (train_dataset, val_dataset) filtered by sequence length.

    Raises:
        ValueError: no examples remain after filtering and sampling.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Stage {stage}: {stage_config['description']}")
    print(f"{banner}\n")

    # Load both dialect corpora
    print("Loading West Syriac corpus...")
    west_dataset = load_dataset("json", data_files=west_data, split="train")
    print(f" Loaded {len(west_dataset)} examples")
    print("Loading East Syriac corpus...")
    east_dataset = load_dataset("json", data_files=east_data, split="train")
    print(f" Loaded {len(east_dataset)} examples")

    # Combine datasets
    full_dataset = concatenate_datasets([west_dataset, east_dataset])
    print(f"Total combined: {len(full_dataset)} examples")

    # Add source length column for filtering/sorting
    full_dataset = full_dataset.map(
        lambda x: {"src_length": get_src_length(x)}, num_proc=4
    )
    # Sort by length (curriculum: short → long)
    full_dataset = full_dataset.sort("src_length")

    # Apply length filter if specified
    max_len = stage_config["max_src_length"]
    if max_len is not None:
        print(f"\nFiltering to sequences with src_length <= {max_len} characters...")
        filtered_dataset = full_dataset.filter(lambda x: x["src_length"] <= max_len)
        print(f" After filtering: {len(filtered_dataset)} examples")
    else:
        filtered_dataset = full_dataset
        print("\nNo length filter applied (using all sequences)")

    # Sample to target size
    num_samples = min(stage_config["num_samples"], len(filtered_dataset))
    print(f"\nSampling {num_samples} examples for training...")

    # For stages 2+, short examples are mixed in to limit catastrophic forgetting
    sampled_dataset = _sample_for_stage(
        full_dataset, filtered_dataset, stage_config, stage, num_samples
    )
    print(f" Final training pool: {len(sampled_dataset)} examples")

    if len(sampled_dataset) == 0:
        # Fail here with a clear message instead of deep inside train_test_split.
        raise ValueError(
            f"Stage {stage}: no training examples available after filtering "
            f"(max_src_length={max_len}); check the corpus files"
        )

    # Split into train/validation (90/10 for stages 4+, 80/20 for earlier stages)
    # NOTE(review): the pool can contain duplicate rows (oversampled ranges, and
    # the remainder/main draw overlaps the short/new-range draws), so the same
    # example may land in both train and validation sets.
    val_ratio = 0.1 if stage >= 4 else 0.2
    dataset_split = sampled_dataset.train_test_split(test_size=val_ratio, seed=42)
    train_dataset = dataset_split["train"]
    val_dataset = dataset_split["test"]
    print(f"\nTrain set: {len(train_dataset)} examples")
    print(f"Validation set: {len(val_dataset)} examples")

    # Report length statistics
    train_lengths = train_dataset["src_length"]
    print("\nSource length statistics (train):")
    print(f" Min: {min(train_lengths)}, Max: {max(train_lengths)}")
    print(
        f" Mean: {np.mean(train_lengths):.1f}, Median: {np.median(train_lengths):.1f}"
    )

    return train_dataset, val_dataset
# =============================================================================
# Evaluation Metrics
# =============================================================================
def compute_cer(pred_str: str, target_str: str) -> float:
    """
    Compute Character Error Rate between prediction and target.

    CER = (substitutions + insertions + deletions) / len(target), where the
    numerator is the Levenshtein edit distance. The value can exceed 1.0 when
    the prediction is much longer than the target.

    Args:
        pred_str: Predicted string
        target_str: Reference string

    Returns:
        The CER as a float. For an empty target: 0.0 if the prediction is
        also empty, otherwise 1.0.
    """
    if not target_str:
        return 0.0 if not pred_str else 1.0

    # Rolling-row Levenshtein: only the previous row of the DP table is needed,
    # so memory is O(len(target)) instead of O(len(pred) * len(target)).
    previous_row = list(range(len(target_str) + 1))
    for i, pred_char in enumerate(pred_str, start=1):
        current_row = [i]  # distance from pred[:i] to the empty prefix
        for j, target_char in enumerate(target_str, start=1):
            if pred_char == target_char:
                cost = previous_row[j - 1]
            else:
                cost = 1 + min(
                    previous_row[j],  # delete from prediction
                    current_row[j - 1],  # insert into prediction
                    previous_row[j - 1],  # substitute
                )
            current_row.append(cost)
        previous_row = current_row

    return previous_row[-1] / len(target_str)
def create_compute_metrics(tokeniser):
    """Return a ``compute_metrics`` callback that decodes with ``tokeniser``."""
    eval_calls = 0  # number of evaluations seen so far (drives periodic logging)

    def compute_metrics(eval_preds):
        """
        Score one evaluation pass.

        Returns a dict with the mean character error rate ("cer"), the
        fraction of whitespace-stripped exact matches ("exact_match"), the
        mean prediction/target length ratio ("length_ratio") and the mean
        under-generation penalty ("length_penalty").
        """
        nonlocal eval_calls
        preds, labels = eval_preds

        # -100 is not a real token id; swap it for the pad token before decoding.
        labels = np.where(labels != -100, labels, tokeniser.pad_token_id)
        preds = np.where(preds != -100, preds, tokeniser.pad_token_id)

        pred_strs = tokeniser.batch_decode(preds, skip_special_tokens=True)
        label_strs = tokeniser.batch_decode(labels, skip_special_tokens=True)

        # Show a handful of predictions on the 1st, 3rd, 5th, ... evaluation.
        eval_calls += 1
        if eval_calls % 2 == 1:
            print("\n--- Sample predictions (first 5) ---")
            for idx, prediction in enumerate(pred_strs[:5]):
                reference = label_strs[idx]
                print(f" Target: '{reference}'")
                print(f" Pred: '{prediction}'")
                print(f" CER: {compute_cer(prediction, reference):.3f}")
                print()

        cer_scores = []
        exact_matches = []
        length_ratios = []
        for prediction, reference in zip(pred_strs, label_strs):
            cer_scores.append(compute_cer(prediction, reference))
            exact_matches.append(
                1.0 if prediction.strip() == reference.strip() else 0.0
            )
            # Ratio < 1 means the output is shorter than the target.
            length_ratios.append(len(prediction) / max(len(reference), 1))

        mean_pred_len = np.mean([len(text) for text in pred_strs])
        mean_label_len = np.mean([len(text) for text in label_strs])
        print(
            f" Avg pred len: {mean_pred_len:.1f}, Avg label len: {mean_label_len:.1f}"
        )

        # Only under-generation is penalised (0 = no shortfall, higher = worse);
        # outputs longer than the target contribute 0.
        shortfalls = [max(0, 1 - ratio) for ratio in length_ratios]
        avg_length_penalty = np.mean(shortfalls)
        avg_length_ratio = np.mean(length_ratios)
        print(
            f" Avg length ratio: {avg_length_ratio:.3f}, Avg length penalty: {avg_length_penalty:.3f}"
        )

        return {
            "cer": np.mean(cer_scores),
            "exact_match": np.mean(exact_matches),
            "length_ratio": avg_length_ratio,
            "length_penalty": avg_length_penalty,
        }

    return compute_metrics
# =============================================================================
# Training
# =============================================================================
def _configure_generation(model, tokeniser, stage):
    """Set the decoding settings on ``model.generation_config`` for ``stage``."""
    generation = model.generation_config

    # max_length is total sequence length - set high to avoid truncation
    generation.max_length = 256
    # Deliberately no no_repeat_ngram_size (it blocks valid Syriac patterns) and
    # no repetition_penalty (transliteration has legitimate repetition); any
    # "repetition_penalty" entry in the stage config is therefore not applied.
    generation.eos_token_id = tokeniser.eos_token_id
    generation.pad_token_id = tokeniser.pad_token_id

    # Minimum length to discourage under-generation, applied to all stages
    generation.min_length = 2 if stage < 5 else 3

    # Beam search with a length_penalty that grows by stage to counter
    # systematic under-generation: stage -> (num_beams, length_penalty).
    # Stages not listed (5-6) use the last, strongest setting.
    beam_settings = {
        1: (2, 1.05),
        2: (2, 1.12),
        3: (3, 1.18),
        4: (4, 1.22),
    }
    num_beams, length_penalty = beam_settings.get(stage, (4, 1.25))
    generation.num_beams = num_beams
    generation.length_penalty = length_penalty
    generation.early_stopping = True


def train(args):
    """
    Run one curriculum stage end to end.

    Ensures the balanced corpus exists, builds the model, tokeniser and
    datasets for ``args.stage``, trains with Seq2SeqTrainer (optionally with
    early stopping and checkpoint resumption), saves the final model and
    tokeniser to ``<output_dir>/stage<N>-final``, then makes a best-effort
    attempt to refresh the README metrics.

    Args:
        args: Parsed command-line namespace as produced by parse_args()
            (stage, hf_model, west_data, east_data, tokeniser, output_dir,
            batch_size, no_early_stopping, resume).
    """
    # Ensure balanced corpus exists (auto-regenerate if needed)
    ensure_balanced_corpus()

    stage_config = STAGE_CONFIGS[args.stage]

    # Load model and tokeniser
    model, tokeniser = load_model_and_tokeniser(
        stage=args.stage,
        hf_model=args.hf_model,
        tokeniser_path=args.tokeniser,
    )
    # Enable gradient checkpointing to save memory
    model.gradient_checkpointing_enable()

    # Load and prepare data
    train_dataset, val_dataset = load_and_prepare_data(
        stage_config=stage_config,
        stage=args.stage,
        west_data=args.west_data,
        east_data=args.east_data,
    )

    # Tokenise datasets
    tokenise_fn = create_tokenise_function(tokeniser)
    tokenised_train = train_dataset.map(
        tokenise_fn,
        batched=True,
        remove_columns=train_dataset.column_names,
        desc="Tokenising train set",
    )
    tokenised_eval = val_dataset.map(
        tokenise_fn,
        batched=True,
        remove_columns=val_dataset.column_names,
        desc="Tokenising eval set",
    )

    # Data collator
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokeniser, model=model)

    # Stage-specific hyperparameters for better early learning
    grad_accum = 4 if args.stage <= 2 else 8  # Smaller effective batch for early stages
    label_smooth = 0.05 if args.stage == 1 else (0.08 if args.stage <= 3 else 0.1)
    warmup = 0.10 if args.stage == 1 else 0.06  # More warmup for training from scratch

    training_args = Seq2SeqTrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=stage_config["num_epochs"],
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=grad_accum,
        learning_rate=stage_config["learning_rate"],
        warmup_ratio=warmup,
        weight_decay=0.01,
        label_smoothing_factor=label_smooth,
        save_strategy="epoch",
        save_total_limit=3,
        eval_strategy="epoch",
        logging_dir="logs",
        fp16=torch.cuda.is_available(),
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="none",
        predict_with_generate=True,
        generation_max_length=256,  # Generous headroom for all stages
    )

    # Configure generation settings
    _configure_generation(model, tokeniser, args.stage)

    # Callbacks
    callbacks = []
    if not args.no_early_stopping:
        callbacks.append(
            EarlyStoppingCallback(
                early_stopping_patience=EARLY_STOPPING_PATIENCE,
                early_stopping_threshold=EARLY_STOPPING_THRESHOLD,
            )
        )
        print("\nEarly stopping enabled:")
        print(f" Patience: {EARLY_STOPPING_PATIENCE} evaluations")
        # The callback compares the absolute change in the monitored metric,
        # so report the raw value rather than a percentage.
        print(
            f" Threshold: {EARLY_STOPPING_THRESHOLD} (minimum absolute eval_loss improvement)"
        )

    # Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenised_train,
        eval_dataset=tokenised_eval,
        data_collator=data_collator,
        processing_class=tokeniser,
        compute_metrics=create_compute_metrics(tokeniser),
        callbacks=callbacks,
    )

    # Train
    print(f"\n{'=' * 60}")
    print("Starting training...")
    print(f"{'=' * 60}\n")

    # Handle checkpoint resumption
    resume_from_checkpoint = None
    if args.resume:
        if args.resume == "auto":
            # Auto-detect: let Trainer find the last checkpoint
            resume_from_checkpoint = True
            print("Resuming from last checkpoint (auto-detect)...")
        else:
            # Specific checkpoint path provided
            resume_from_checkpoint = args.resume
            print(f"Resuming from checkpoint: {args.resume}")

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # Save final model
    final_output_dir = f"{args.output_dir}/stage{args.stage}-final"
    model.save_pretrained(final_output_dir)
    tokeniser.save_pretrained(final_output_dir)
    print(f"\n{'=' * 60}")
    print(f"Stage {args.stage} training complete!")
    print(f"Model saved to: {final_output_dir}")
    print(f"{'=' * 60}")

    # Update README metrics. This is best-effort: any failure (including the
    # helper module being unavailable) is reported and training still succeeds.
    try:
        from update_readme_metrics import (extract_metrics,
                                           find_best_checkpoint,
                                           update_readme_metrics)

        checkpoint_dir = find_best_checkpoint(Path(args.output_dir))
        if checkpoint_dir:
            metrics = extract_metrics(checkpoint_dir)
            readme_path = _PROJECT_ROOT / "README.md"
            if update_readme_metrics(readme_path, metrics):
                print(
                    f"\nREADME metrics updated automatically from {checkpoint_dir.name}"
                )
    except Exception as e:
        print(f"\nNote: Could not auto-update README metrics: {e}")

    # Print next steps. Only suggest a follow-up run when a next stage exists;
    # the previous `stage <= 6` check was always true and advertised stage 7.
    print("\nNext steps:")
    print(
        f" 1. Upload model to HuggingFace (e.g., 'your-username/aramt5-v{args.stage}')"
    )
    next_stage = args.stage + 1
    if next_stage in STAGE_CONFIGS:
        print(f" 2. Run stage {next_stage}:")
        print(
            f" python src/train_t5.py --stage {next_stage} --hf-model your-username/aramt5-v{args.stage}"
        )
# Script entry point: parse the command line and run the selected stage.
if __name__ == "__main__":
    train(parse_args())