# aramt5 / src / train_t5.py
# crossroderick's picture
# v3.2 update
# d43946a
"""
AramT5 Curriculum Learning Trainer
Features:
- Curriculum learning: short → long sequences
- Catastrophic forgetting mitigation: mixes short examples into later stages
- Character Error Rate (CER) evaluation for transliteration quality
- Early stopping based on validation loss improvement threshold
"""
import argparse
import subprocess
import sys
from pathlib import Path
import numpy as np
import torch
from datasets import concatenate_datasets, load_dataset
from transformers import (DataCollatorForSeq2Seq, EarlyStoppingCallback,
Seq2SeqTrainer, Seq2SeqTrainingArguments, T5Config,
T5ForConditionalGeneration, T5TokenizerFast)
# =============================================================================
# Configuration
# =============================================================================
# Resolve paths relative to project root (parent of src/).
# This file is expected to live in <project root>/src/, so two `.parent`
# hops from the resolved file path land on the project root.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default paths (relative to project root)
# Use balanced corpus: 40% single, 30% two-word, 30% multi-word
# (augmented corpus was 98.5% single, causing truncated multi-word outputs)
# NOTE: the DEFAULT_* values are plain `str` (they are used as argparse
# defaults), whereas the AUGMENTED_*/CLEAN_* values below are `Path` objects
# (they are used for .exists()/.stat() freshness checks).
DEFAULT_WEST_DATA = str(_PROJECT_ROOT / "src/data/syriac_west_balanced_corpus.jsonl")
DEFAULT_EAST_DATA = str(_PROJECT_ROOT / "src/data/syriac_east_balanced_corpus.jsonl")
# Source files for balancing (input to balance_corpus.py)
AUGMENTED_WEST_DATA = _PROJECT_ROOT / "src/data/syriac_west_augmented_corpus.jsonl"
AUGMENTED_EAST_DATA = _PROJECT_ROOT / "src/data/syriac_east_augmented_corpus.jsonl"
# Source files for augmentation (input to augment_atomic_tokens.py)
CLEAN_WEST_DATA = _PROJECT_ROOT / "src/data/syriac_west_clean_corpus.jsonl"
CLEAN_EAST_DATA = _PROJECT_ROOT / "src/data/syriac_east_clean_corpus.jsonl"
# Local tokeniser directory and the root directory for Trainer checkpoints.
DEFAULT_TOKENISER = str(_PROJECT_ROOT / "src/tokeniser")
DEFAULT_OUTPUT_DIR = str(_PROJECT_ROOT / "checkpoints")
def ensure_augmented_corpus():
    """
    Make sure both augmented corpus files are present and fresh.

    The augmented corpus is considered stale when either file is missing, or
    when a clean corpus file exists and has a later modification time than
    its augmented counterpart. A stale corpus is rebuilt by running
    augment_atomic_tokens.py with the current interpreter from the project
    root.

    Raises:
        FileNotFoundError: The augmentation script is missing.
        RuntimeError: The augmentation script exited with a non-zero status.
    """
    if AUGMENTED_WEST_DATA.exists() and AUGMENTED_EAST_DATA.exists():
        # Both outputs exist: compare each against its clean source (if any).
        stale = False
        source_output_pairs = (
            (CLEAN_WEST_DATA, AUGMENTED_WEST_DATA),
            (CLEAN_EAST_DATA, AUGMENTED_EAST_DATA),
        )
        for clean_path, augmented_path in source_output_pairs:
            if not clean_path.exists():
                continue
            if clean_path.stat().st_mtime > augmented_path.stat().st_mtime:
                print("Clean corpus is newer than augmented corpus, regenerating...")
                stale = True
    else:
        print("Augmented corpus not found, will generate...")
        stale = True

    if not stale:
        print("Augmented corpus is up-to-date.")
        return

    script_path = _PROJECT_ROOT / "src/data/augment_atomic_tokens.py"
    if not script_path.exists():
        raise FileNotFoundError(
            f"Cannot regenerate augmented corpus: {script_path} not found"
        )

    print("Running augment_atomic_tokens.py to generate augmented training data...")
    completed = subprocess.run(
        [sys.executable, str(script_path)],
        cwd=str(_PROJECT_ROOT),
        capture_output=True,
        text=True,
    )
    if completed.returncode != 0:
        print(f"Error running augment_atomic_tokens.py:\n{completed.stderr}")
        raise RuntimeError("Failed to generate augmented corpus")

    # Output was captured, so replay the script's stdout once it has finished.
    print(completed.stdout)
    print("Augmented corpus generated successfully.")
def ensure_balanced_corpus():
    """
    Make sure both balanced corpus files are present and fresh.

    Pipeline: clean_corpus -> augmented_corpus -> balanced_corpus

    The upstream augmented corpus is checked (and rebuilt if needed) first.
    The balanced corpus is then considered stale when either file is missing,
    or when an augmented file exists and has a later modification time than
    its balanced counterpart. A stale corpus is rebuilt by running
    balance_corpus.py with the current interpreter from the project root.

    Raises:
        FileNotFoundError: The balancing script is missing.
        RuntimeError: The balancing script exited with a non-zero status.
    """
    # Upstream dependency first, so the mtime comparison below sees fresh inputs.
    ensure_augmented_corpus()

    balanced_west = Path(DEFAULT_WEST_DATA)
    balanced_east = Path(DEFAULT_EAST_DATA)

    if balanced_west.exists() and balanced_east.exists():
        stale = False
        source_output_pairs = (
            (AUGMENTED_WEST_DATA, balanced_west),
            (AUGMENTED_EAST_DATA, balanced_east),
        )
        for augmented_path, balanced_path in source_output_pairs:
            if not augmented_path.exists():
                continue
            if augmented_path.stat().st_mtime > balanced_path.stat().st_mtime:
                print("Augmented corpus is newer than balanced corpus, regenerating...")
                stale = True
    else:
        print("Balanced corpus not found, will generate...")
        stale = True

    if not stale:
        print("Balanced corpus is up-to-date.")
        return

    script_path = _PROJECT_ROOT / "src/data/balance_corpus.py"
    if not script_path.exists():
        raise FileNotFoundError(
            f"Cannot regenerate balanced corpus: {script_path} not found"
        )

    print("Running balance_corpus.py to generate balanced training data...")
    completed = subprocess.run(
        [sys.executable, str(script_path)],
        cwd=str(_PROJECT_ROOT),
        capture_output=True,
        text=True,
    )
    if completed.returncode != 0:
        print(f"Error running balance_corpus.py:\n{completed.stderr}")
        raise RuntimeError("Failed to generate balanced corpus")

    # Output was captured, so replay the script's stdout once it has finished.
    print(completed.stdout)
    print("Balanced corpus generated successfully.")
# Curriculum learning stage configurations
#
# Keyed by stage number (the value of --stage). Keys consumed by the code in
# this file:
#   description      - printed in the stage banner by load_and_prepare_data()
#   num_samples      - target size of the sampled pool (before train/val split)
#   max_src_length   - upper bound on source length, in characters
#   short_mix_ratio  - fraction of the pool drawn from earlier-stage lengths
#   short_threshold  - maximum source length counted as "short" for that mix
#   new_range_ratio  - fraction of the pool drawn from the newly opened range
#   new_range_min    - minimum source length of the newly opened range
#   num_epochs       - passed to the trainer as num_train_epochs
#   learning_rate    - passed to the trainer as learning_rate
# NOTE(review): "repetition_penalty" (stages 5 and 6) is not read by any code
# visible in this file, and train() carries a comment saying repetition
# penalty is deliberately not used - confirm whether these entries are stale.
STAGE_CONFIGS = {
    1: {
        "description": "Baseline: short sequences only",
        "num_samples": 20_000,
        "max_src_length": 15,  # Characters in source (short words)
        "short_mix_ratio": 0.0,  # No mixing needed in stage 1
        "num_epochs": 30,
        "learning_rate": 3e-4,
    },
    2: {
        "description": "Expansion: short phrases",
        "num_samples": 40_000,
        "max_src_length": 30,
        "short_mix_ratio": 0.12,  # 12% short examples from previous stages
        "short_threshold": 15,  # ≤15 chars (Stage 1)
        "new_range_ratio": 0.50,  # 50% from new range (16-30 chars)
        "new_range_min": 16,
        "num_epochs": 20,
        "learning_rate": 1.2e-4,
    },
    3: {
        "description": "Expansion: medium phrases",
        "num_samples": 60_000,
        "max_src_length": 50,
        "short_mix_ratio": 0.12,  # 12% short examples from previous stages
        "short_threshold": 30,  # ≤30 chars (Stage 1+2)
        "new_range_ratio": 0.50,  # 50% from new range (31-50 chars)
        "new_range_min": 31,
        "num_epochs": 20,
        "learning_rate": 1e-4,
    },
    4: {
        "description": "Extension: longer phrases",
        "num_samples": 120_000,  # Increased to better learn multi-word patterns
        "max_src_length": 70,
        "short_mix_ratio": 0.18,  # 18% short examples from previous stages (boosted for retention)
        "short_threshold": 50,  # ≤50 chars (Stage 1+2+3)
        "new_range_ratio": 0.45,  # 45% from new range (51-70 chars)
        "new_range_min": 51,
        "num_epochs": 10,
        "learning_rate": 8e-5,  # Higher LR to unlearn early-stopping bias from imbalanced data
    },
    5: {
        "description": "Extension: longer sentences",
        "num_samples": 150_000,  # Increased to better learn multi-word patterns
        "max_src_length": 100,
        "short_mix_ratio": 0.18,  # 18% short examples from previous stages (boosted for retention)
        "short_threshold": 70,  # ≤70 chars (Stage 1+2+3+4)
        "new_range_ratio": 0.45,  # 45% from new range (71-100 chars)
        "new_range_min": 71,
        "num_epochs": 10,
        "learning_rate": 5e-5,  # Slightly higher to reinforce multi-word patterns
        "repetition_penalty": 1.2,
    },
    6: {
        "description": "Full practical corpus: sentences and short paragraphs",
        "num_samples": 180_000,  # Increased to better learn multi-word patterns
        "max_src_length": 150,
        "short_mix_ratio": 0.20,  # 20% short examples from previous stages (highest retention)
        "short_threshold": 100,  # ≤100 chars (Stage 1+2+3+4+5)
        "new_range_ratio": 0.40,  # 40% from new range (101-150 chars)
        "new_range_min": 101,
        "num_epochs": 10,
        "learning_rate": 4e-5,  # Fine-tuning polish
        "repetition_penalty": 1.2,
    },
}
# Early stopping config (both values are passed to EarlyStoppingCallback in train())
# Number of consecutive evaluations without sufficient improvement before stopping.
EARLY_STOPPING_PATIENCE = 3
# Minimum improvement in the monitored metric (eval_loss) that counts as progress.
# NOTE(review): the original comment called this a "0.5% improvement threshold";
# transformers' EarlyStoppingCallback appears to compare an absolute metric
# delta against this value rather than a percentage - confirm against the docs.
EARLY_STOPPING_THRESHOLD = 0.005
def parse_args():
    """
    Parse the trainer's command-line options.

    Returns:
        The argparse namespace with attributes stage, hf_model, west_data,
        east_data, tokeniser, output_dir, batch_size, no_early_stopping and
        resume.
    """
    cli = argparse.ArgumentParser(description="AramT5 Curriculum Learning Trainer")

    # (flag, add_argument keyword arguments), registered in display order.
    option_table = (
        (
            "--stage",
            {
                "type": int,
                "default": 1,
                "choices": [1, 2, 3, 4, 5, 6],
                "help": "Training stage (1=baseline, 2=medium-long, 3=expansion, 4=extension, 5=longer sentences, 6=full practical)",
            },
        ),
        (
            "--hf-model",
            {
                "type": str,
                "default": None,
                "help": "HuggingFace model ID to fine-tune (required for stage 2+)",
            },
        ),
        (
            "--west-data",
            {
                "type": str,
                "default": DEFAULT_WEST_DATA,
                "help": "Path to West Syriac corpus",
            },
        ),
        (
            "--east-data",
            {
                "type": str,
                "default": DEFAULT_EAST_DATA,
                "help": "Path to East Syriac corpus",
            },
        ),
        (
            "--tokeniser",
            {
                "type": str,
                "default": DEFAULT_TOKENISER,
                "help": "Path to tokeniser",
            },
        ),
        (
            "--output-dir",
            {
                "type": str,
                "default": DEFAULT_OUTPUT_DIR,
                "help": "Output directory for checkpoints",
            },
        ),
        (
            "--batch-size",
            {
                "type": int,
                "default": 16,
                "help": "Per-device batch size",
            },
        ),
        (
            "--no-early-stopping",
            {
                "action": "store_true",
                "help": "Disable early stopping",
            },
        ),
        (
            # Bare `--resume` yields "auto"; `--resume PATH` yields PATH;
            # omitting the flag yields None.
            "--resume",
            {
                "type": str,
                "nargs": "?",
                "const": "auto",
                "default": None,
                "help": "Resume from checkpoint. Use --resume for auto-detect or --resume path/to/checkpoint",
            },
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()
# =============================================================================
# Model Loading
# =============================================================================
def load_model_and_tokeniser(
    stage: int = 1,
    hf_model: str | None = None,
    tokeniser_path: str = DEFAULT_TOKENISER,
):
    """
    Build the model and tokeniser for the requested training stage.

    Stage 1 creates a freshly initialised T5 whose vocabulary size and pad
    token come from the local tokeniser; every later stage loads existing
    weights via ``from_pretrained(hf_model)``.

    Args:
        stage: Training stage (1 = from scratch, anything else = fine-tune)
        hf_model: Model identifier handed to from_pretrained (needed when stage != 1)
        tokeniser_path: Path to local tokeniser directory
    Returns:
        Tuple of (model, tokeniser)
    Raises:
        ValueError: stage is not 1 and no hf_model was supplied.
    """
    tokeniser = T5TokenizerFast.from_pretrained(tokeniser_path)

    if stage != 1:
        # Later stages continue from previously trained weights.
        if not hf_model:
            raise ValueError(f"Stage {stage} requires --hf-model argument")
        print(f"Stage {stage}: Loading model from HuggingFace: {hf_model}")
        return T5ForConditionalGeneration.from_pretrained(hf_model), tokeniser

    print("Stage 1: Initialising new model from scratch...")
    pad_id = tokeniser.pad_token_id
    scratch_config = T5Config(
        vocab_size=tokeniser.vocab_size,
        d_model=512,
        d_ff=2048,
        num_layers=6,
        num_heads=8,
        pad_token_id=pad_id,
        # The decoder is started from the pad token.
        decoder_start_token_id=pad_id,
        tie_word_embeddings=True,
    )
    return T5ForConditionalGeneration(scratch_config), tokeniser
# =============================================================================
# Data Processing
# =============================================================================
def get_src_length(example):
    """Return the character count of an example's source text (used for curriculum filtering)."""
    source_text = example["transliteration"]["src"]
    return len(source_text)
def create_tokenise_function(tokeniser):
    """Build a batch tokenisation callable bound to the given tokeniser."""
    pad_id = tokeniser.pad_token_id

    def _prefixed_source(record: dict) -> str:
        # A missing dialect is treated as West Syriac.
        if record.get("dialect", "west") == "east":
            return f"Syriac2EastLatin: {record['src']}"
        return f"Syriac2WestLatin: {record['src']}"

    def tokenise_function(example: dict) -> dict:
        """
        Tokenise a batch, prepending a dialect-specific task prefix to each source.

        Task prefixes:
        - "Syriac2WestLatin: " for West Syriac (Serto)
        - "Syriac2EastLatin: " for East Syriac (Madnḥaya)

        Sources and targets are each truncated/padded to 256 tokens; pad
        positions in the labels become -100 so the loss skips them.
        """
        records = example["transliteration"]
        sources = [_prefixed_source(record) for record in records]
        references = [record["tgt"] for record in records]

        encoded = tokeniser(
            sources, max_length=256, truncation=True, padding="max_length"
        )
        reference_ids = tokeniser(
            references, max_length=256, truncation=True, padding="max_length"
        )["input_ids"]

        encoded["labels"] = [
            [-100 if token == pad_id else token for token in row]
            for row in reference_ids
        ]
        return encoded

    return tokenise_function
def load_and_prepare_data(
    stage_config: dict,
    stage: int = 1,
    west_data: str = DEFAULT_WEST_DATA,
    east_data: str = DEFAULT_EAST_DATA,
):
    """
    Load and prepare data according to curriculum learning stage.

    Steps, in order:
      1. Load both JSONL corpora and concatenate them.
      2. Add a "src_length" column (character count of the source text).
      3. Keep only examples with src_length <= stage_config["max_src_length"].
      4. Draw a pool of up to stage_config["num_samples"] examples using one
         of four strategies (see the branch comments below).
      5. Split the pool into train/validation and print length statistics.

    Args:
        stage_config: Configuration dict for the current stage
        stage: Training stage number (for logging and mixing logic)
        west_data: Path to West Syriac corpus JSONL file
        east_data: Path to East Syriac corpus JSONL file
    Returns:
        Tuple of (train_dataset, val_dataset) filtered by sequence length.
        The datasets are not tokenised and still carry the "src_length" column.

    NOTE(review): the mixing strategies sample their sub-pools independently
    from overlapping ranges (and may repeat scarce examples), so the same
    example can appear more than once in the pool and can land on both sides
    of the train/validation split.
    """
    print(f"\n{'=' * 60}")
    print(f"Stage {stage}: {stage_config['description']}")
    print(f"{'=' * 60}\n")
    # Load both dialect corpora
    print("Loading West Syriac corpus...")
    west_dataset = load_dataset("json", data_files=west_data, split="train")
    print(f" Loaded {len(west_dataset)} examples")
    print("Loading East Syriac corpus...")
    east_dataset = load_dataset("json", data_files=east_data, split="train")
    print(f" Loaded {len(east_dataset)} examples")
    # Combine datasets
    full_dataset = concatenate_datasets([west_dataset, east_dataset])
    print(f"Total combined: {len(full_dataset)} examples")
    # Add source length column for filtering/sorting
    full_dataset = full_dataset.map(
        lambda x: {"src_length": get_src_length(x)}, num_proc=4
    )
    # Sort by length (curriculum: short → long)
    # NOTE: every sampling branch below shuffles with a fixed seed, so this
    # ordering only determines the input order of those shuffles.
    full_dataset = full_dataset.sort("src_length")
    # Apply length filter if specified
    max_len = stage_config["max_src_length"]
    if max_len is not None:
        print(f"\nFiltering to sequences with src_length <= {max_len} characters...")
        filtered_dataset = full_dataset.filter(lambda x: x["src_length"] <= max_len)
        print(f" After filtering: {len(filtered_dataset)} examples")
    else:
        filtered_dataset = full_dataset
        print("\nNo length filter applied (using all sequences)")
    # Sample to target size (capped by how many examples survived the filter)
    num_samples = min(stage_config["num_samples"], len(filtered_dataset))
    print(f"\nSampling {num_samples} examples for training...")
    # For stages 2+, mix in some short examples to prevent catastrophic forgetting
    short_mix_ratio = stage_config["short_mix_ratio"]
    middle_oversample = stage_config.get("middle_oversample", False)
    if middle_oversample:
        # Strategy A: short + oversampled middle range + main.
        # NOTE(review): no stage config visible in this file sets
        # "middle_oversample", so this branch looks unused - confirm.
        # Stage 4 special handling: oversample the rare 15-100 char range
        # to build bridge competence before full corpus
        num_short = int(num_samples * short_mix_ratio)
        num_middle = int(num_samples * 0.40)  # 40% from middle range (15-100)
        num_main = num_samples - num_short - num_middle
        # Short examples (≤15 chars = Stage 1 range) for forgetting mitigation
        short_threshold = 15
        short_examples = full_dataset.filter(
            lambda x: x["src_length"] <= short_threshold
        )
        short_examples = short_examples.shuffle(seed=42).select(
            range(min(num_short, len(short_examples)))
        )
        print(f" Short examples (≤{short_threshold} chars): {len(short_examples)}")
        # Middle-range examples (15-100 chars) - oversample these rare sequences
        middle_examples = filtered_dataset.filter(lambda x: 15 < x["src_length"] <= 100)
        # Repeat/oversample if needed since these are scarce
        if len(middle_examples) < num_middle:
            # Repeat the middle examples to reach target
            # NOTE(review): divides by len(middle_examples); an empty middle
            # pool would raise ZeroDivisionError here.
            repeats_needed = (num_middle // len(middle_examples)) + 1
            middle_repeated = concatenate_datasets([middle_examples] * repeats_needed)
            middle_examples = middle_repeated.shuffle(seed=42).select(range(num_middle))
            print(
                f" Middle-range examples (15-100 chars, oversampled): {len(middle_examples)}"
            )
        else:
            middle_examples = middle_examples.shuffle(seed=42).select(range(num_middle))
            print(f" Middle-range examples (15-100 chars): {len(middle_examples)}")
        # Main examples from full filtered set
        main_examples = filtered_dataset.shuffle(seed=43).select(
            range(min(num_main, len(filtered_dataset)))
        )
        print(f" Main examples: {len(main_examples)}")
        # Combine and shuffle
        sampled_dataset = concatenate_datasets(
            [short_examples, middle_examples, main_examples]
        )
        sampled_dataset = sampled_dataset.shuffle(seed=42)
    elif short_mix_ratio > 0 and stage > 1:
        # Stratified sampling: ensure we get examples from the NEW length range
        new_range_ratio = stage_config.get("new_range_ratio", 0)
        new_range_min = stage_config.get("new_range_min", 0)
        num_short = int(num_samples * short_mix_ratio)
        if new_range_ratio > 0 and new_range_min > 0:
            # Strategy B (used by stages 2-6 as configured in STAGE_CONFIGS).
            # Stratified: short + new_range + remainder
            num_new_range = int(num_samples * new_range_ratio)
            num_remainder = num_samples - num_short - num_new_range
            # Short examples = everything from previous stages (for forgetting mitigation)
            short_threshold = stage_config.get("short_threshold", 15)
            short_examples = full_dataset.filter(
                lambda x, thresh=short_threshold: x["src_length"] <= thresh
            )
            short_examples = short_examples.shuffle(seed=42).select(
                range(min(num_short, len(short_examples)))
            )
            print(
                f" Short examples (≤{short_threshold} chars, previous stages): {len(short_examples)}"
            )
            # New range examples - these are what the model needs to learn
            # (upper bound max_len is already enforced by filtered_dataset)
            new_range_examples = filtered_dataset.filter(
                lambda x, min_len=new_range_min: x["src_length"] >= min_len
            )
            print(
                f" New range pool ({new_range_min}-{max_len} chars): {len(new_range_examples)} available"
            )
            # Oversample if needed (these are scarce!)
            if len(new_range_examples) < num_new_range:
                if len(new_range_examples) > 0:
                    # Repeat the pool enough times to cover the quota, then trim.
                    repeats_needed = (num_new_range // len(new_range_examples)) + 1
                    new_range_repeated = concatenate_datasets(
                        [new_range_examples] * repeats_needed
                    )
                    new_range_examples = new_range_repeated.shuffle(seed=42).select(
                        range(num_new_range)
                    )
                    print(
                        f" New range examples (oversampled {repeats_needed}x): {len(new_range_examples)}"
                    )
                else:
                    print(f" WARNING: No examples in new range!")
                    new_range_examples = full_dataset.filter(lambda x: False)  # empty
            else:
                new_range_examples = new_range_examples.shuffle(seed=42).select(
                    range(num_new_range)
                )
                print(f" New range examples: {len(new_range_examples)}")
            # Remainder from full filtered set (includes all lengths up to max)
            # A different seed (43) gives a different permutation than the
            # seed-42 draws above; overlap with those draws is still possible.
            remainder_examples = filtered_dataset.shuffle(seed=43).select(
                range(min(num_remainder, len(filtered_dataset)))
            )
            print(f" Remainder examples: {len(remainder_examples)}")
            # Combine and shuffle
            sampled_dataset = concatenate_datasets(
                [short_examples, new_range_examples, remainder_examples]
            )
            sampled_dataset = sampled_dataset.shuffle(seed=42)
        else:
            # Strategy C: reached only when a stage config has a positive
            # short_mix_ratio but no new_range_ratio/new_range_min.
            # Original logic: just short + main
            num_main = num_samples - num_short
            # Get short examples = everything from previous stages
            short_threshold = stage_config.get("short_threshold", 15)
            short_examples = full_dataset.filter(
                lambda x, thresh=short_threshold: x["src_length"] <= thresh
            )
            short_examples = short_examples.shuffle(seed=42).select(
                range(min(num_short, len(short_examples)))
            )
            print(
                f" Short examples (≤{short_threshold} chars, previous stages): {len(short_examples)}"
            )
            # Get main examples from filtered dataset
            # Apply minimum length filter for main examples in later stages
            min_len = stage_config.get("min_src_length", 0)
            if min_len > 0:
                main_pool = filtered_dataset.filter(
                    lambda x: x["src_length"] >= min_len
                )
                print(
                    f" Main pool after min_length={min_len} filter: {len(main_pool)} examples"
                )
            else:
                main_pool = filtered_dataset
            main_examples = main_pool.shuffle(seed=42).select(
                range(min(num_main, len(main_pool)))
            )
            print(f" Main examples: {len(main_examples)}")
            # Combine and shuffle
            sampled_dataset = concatenate_datasets([short_examples, main_examples])
            sampled_dataset = sampled_dataset.shuffle(seed=42)
    else:
        # Strategy D (stage 1, or no mixing requested): plain random sample.
        sampled_dataset = filtered_dataset.shuffle(seed=42).select(range(num_samples))
    print(f" Final training pool: {len(sampled_dataset)} examples")
    # Split into train/validation (90/10 for stages 4 and above, 80/20 for earlier stages)
    val_ratio = 0.1 if stage >= 4 else 0.2
    dataset_split = sampled_dataset.train_test_split(test_size=val_ratio, seed=42)
    train_dataset = dataset_split["train"]
    val_dataset = dataset_split["test"]
    print(f"\nTrain set: {len(train_dataset)} examples")
    print(f"Validation set: {len(val_dataset)} examples")
    # Report length statistics
    train_lengths = train_dataset["src_length"]
    print(f"\nSource length statistics (train):")
    print(f" Min: {min(train_lengths)}, Max: {max(train_lengths)}")
    print(
        f" Mean: {np.mean(train_lengths):.1f}, Median: {np.median(train_lengths):.1f}"
    )
    return train_dataset, val_dataset
# =============================================================================
# Evaluation Metrics
# =============================================================================
def compute_cer(pred_str: str, target_str: str) -> float:
    """
    Return the Character Error Rate of a prediction against its target.

    CER = Levenshtein distance(pred, target) / len(target), where the
    distance counts single-character substitutions, insertions and deletions.
    An empty target yields 0.0 for an empty prediction and 1.0 otherwise.
    The value can exceed 1.0 when the prediction is much longer than the target.
    """
    target_len = len(target_str)
    if not target_len:
        return 1.0 if pred_str else 0.0

    # Row-by-row edit distance: `previous` holds the distances between the
    # prediction prefix processed so far and every prefix of the target.
    previous = list(range(target_len + 1))
    for row, pred_char in enumerate(pred_str, start=1):
        current = [row]
        for col, target_char in enumerate(target_str, start=1):
            if pred_char == target_char:
                cost = previous[col - 1]
            else:
                cost = 1 + min(previous[col], current[col - 1], previous[col - 1])
            current.append(cost)
        previous = current

    return previous[target_len] / target_len
def create_compute_metrics(tokeniser):
    """Build the evaluation-metrics callable bound to the given tokeniser."""
    eval_calls = 0  # Number of evaluations seen so far (drives periodic logging)

    def compute_metrics(eval_preds):
        """
        Decode predictions/labels and return CER, exact match and length metrics.

        Returns a dict with keys "cer", "exact_match", "length_ratio" and
        "length_penalty" (each the mean over the evaluated examples).
        """
        nonlocal eval_calls
        preds, labels = eval_preds
        pad_id = tokeniser.pad_token_id

        # -100 marks ignored positions; swap in the pad id so decoding works.
        labels = np.where(labels != -100, labels, pad_id)
        preds = np.where(preds != -100, preds, pad_id)

        pred_strs = tokeniser.batch_decode(preds, skip_special_tokens=True)
        label_strs = tokeniser.batch_decode(labels, skip_special_tokens=True)

        # Show a handful of decoded samples on the 1st, 3rd, 5th, ... evaluation.
        eval_calls += 1
        if eval_calls % 2 == 1:
            print("\n--- Sample predictions (first 5) ---")
            for shown_pred, shown_target in zip(pred_strs[:5], label_strs[:5]):
                print(f" Target: '{shown_target}'")
                print(f" Pred: '{shown_pred}'")
                print(f" CER: {compute_cer(shown_pred, shown_target):.3f}")
                print()

        cer_scores = []
        exact_matches = []
        length_ratios = []
        for pred, target in zip(pred_strs, label_strs):
            cer_scores.append(compute_cer(pred, target))
            exact_matches.append(1.0 if pred.strip() == target.strip() else 0.0)
            # Ratio < 1 means the output is shorter than the target.
            length_ratios.append(len(pred) / max(len(target), 1))

        mean_pred_len = np.mean([len(text) for text in pred_strs])
        mean_label_len = np.mean([len(text) for text in label_strs])
        print(
            f" Avg pred len: {mean_pred_len:.1f}, Avg label len: {mean_label_len:.1f}"
        )

        # Only under-generation is penalised (0 = no shortfall, higher = worse);
        # outputs longer than the target contribute 0.
        length_penalties = [max(0, 1 - ratio) for ratio in length_ratios]
        avg_length_penalty = np.mean(length_penalties)
        avg_length_ratio = np.mean(length_ratios)
        print(
            f" Avg length ratio: {avg_length_ratio:.3f}, Avg length penalty: {avg_length_penalty:.3f}"
        )

        return {
            "cer": np.mean(cer_scores),
            "exact_match": np.mean(exact_matches),
            "length_ratio": avg_length_ratio,
            "length_penalty": avg_length_penalty,
        }

    return compute_metrics
# =============================================================================
# Training
# =============================================================================
def train(args):
    """
    Run one curriculum-learning stage end to end.

    Steps: make sure the balanced corpus exists, load the model/tokeniser for
    the stage, sample and tokenise the data, configure the trainer and
    generation settings, train (optionally resuming from a checkpoint), save
    the final model and tokeniser under ``<output_dir>/stage<N>-final``, make
    a best-effort attempt to refresh the README metrics, and print next steps.

    Args:
        args: Parsed command-line namespace from parse_args(). Attributes
            read here: stage, hf_model, tokeniser, west_data, east_data,
            output_dir, batch_size, no_early_stopping, resume.
    """
    # Ensure balanced corpus exists (auto-regenerate if needed)
    ensure_balanced_corpus()
    stage_config = STAGE_CONFIGS[args.stage]

    # Load model and tokeniser
    model, tokeniser = load_model_and_tokeniser(
        stage=args.stage,
        hf_model=args.hf_model,
        tokeniser_path=args.tokeniser,
    )
    # Enable gradient checkpointing to save memory
    model.gradient_checkpointing_enable()

    # Load and prepare data
    train_dataset, val_dataset = load_and_prepare_data(
        stage_config=stage_config,
        stage=args.stage,
        west_data=args.west_data,
        east_data=args.east_data,
    )

    # Tokenise datasets (all raw columns are dropped; only model inputs remain)
    tokenise_fn = create_tokenise_function(tokeniser)
    tokenised_train = train_dataset.map(
        tokenise_fn,
        batched=True,
        remove_columns=train_dataset.column_names,
        desc="Tokenising train set",
    )
    tokenised_eval = val_dataset.map(
        tokenise_fn,
        batched=True,
        remove_columns=val_dataset.column_names,
        desc="Tokenising eval set",
    )

    # Data collator
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokeniser, model=model)

    # Training arguments
    # Stage-specific hyperparameters for better early learning
    grad_accum = 4 if args.stage <= 2 else 8  # Smaller effective batch for early stages
    label_smooth = 0.05 if args.stage == 1 else (0.08 if args.stage <= 3 else 0.1)
    warmup = 0.10 if args.stage == 1 else 0.06  # More warmup for training from scratch
    training_args = Seq2SeqTrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=stage_config["num_epochs"],
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=grad_accum,
        learning_rate=stage_config["learning_rate"],
        warmup_ratio=warmup,
        weight_decay=0.01,
        label_smoothing_factor=label_smooth,
        save_strategy="epoch",
        save_total_limit=3,
        eval_strategy="epoch",
        logging_dir="logs",
        fp16=torch.cuda.is_available(),
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="none",
        predict_with_generate=True,
        generation_max_length=256,  # Generous headroom for all stages
    )

    # Configure generation settings
    # max_length is total sequence length - set high to avoid truncation
    model.generation_config.max_length = 256
    # Don't use no_repeat_ngram_size - it blocks valid Syriac patterns
    # Don't use repetition_penalty - transliteration has legitimate repetition
    model.generation_config.eos_token_id = tokeniser.eos_token_id
    model.generation_config.pad_token_id = tokeniser.pad_token_id
    # Minimum length and length_penalty to discourage under-generation
    # Applied to ALL stages with progressive values
    model.generation_config.min_length = 2 if args.stage < 5 else 3
    # Use beam search with length_penalty to encourage full-length outputs
    # Progressive beam size and length penalty by stage
    # Increased penalties to counter systematic under-generation (~8% shorter outputs)
    if args.stage == 1:
        model.generation_config.num_beams = 2
        model.generation_config.length_penalty = 1.05  # Slight encouragement from start
    elif args.stage == 2:
        model.generation_config.num_beams = 2
        model.generation_config.length_penalty = 1.12  # Counter under-generation
    elif args.stage == 3:
        model.generation_config.num_beams = 3
        model.generation_config.length_penalty = 1.18
    elif args.stage == 4:
        model.generation_config.num_beams = 4
        model.generation_config.length_penalty = 1.22
    else:  # stages 5-6
        model.generation_config.num_beams = 4
        model.generation_config.length_penalty = (
            1.25  # >1.0 encourages longer sequences
        )
    model.generation_config.early_stopping = True

    # Callbacks
    callbacks = []
    if not args.no_early_stopping:
        callbacks.append(
            EarlyStoppingCallback(
                early_stopping_patience=EARLY_STOPPING_PATIENCE,
                early_stopping_threshold=EARLY_STOPPING_THRESHOLD,
            )
        )
        print(f"\nEarly stopping enabled:")
        print(f" Patience: {EARLY_STOPPING_PATIENCE} evaluations")
        # The callback compares the absolute change in the monitored metric
        # against this threshold, so report it as a raw delta rather than a
        # percentage.
        print(
            f" Threshold: {EARLY_STOPPING_THRESHOLD} minimum improvement in eval_loss"
        )

    # Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenised_train,
        eval_dataset=tokenised_eval,
        data_collator=data_collator,
        processing_class=tokeniser,
        compute_metrics=create_compute_metrics(tokeniser),
        callbacks=callbacks,
    )

    # Train
    print(f"\n{'=' * 60}")
    print("Starting training...")
    print(f"{'=' * 60}\n")

    # Handle checkpoint resumption
    resume_from_checkpoint = None
    if args.resume:
        if args.resume == "auto":
            # Auto-detect: let Trainer find the last checkpoint
            resume_from_checkpoint = True
            print("Resuming from last checkpoint (auto-detect)...")
        else:
            # Specific checkpoint path provided
            resume_from_checkpoint = args.resume
            print(f"Resuming from checkpoint: {args.resume}")
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # Save final model
    final_output_dir = f"{args.output_dir}/stage{args.stage}-final"
    model.save_pretrained(final_output_dir)
    tokeniser.save_pretrained(final_output_dir)
    print(f"\n{'=' * 60}")
    print(f"Stage {args.stage} training complete!")
    print(f"Model saved to: {final_output_dir}")
    print(f"{'=' * 60}")

    # Update README metrics.
    # Deliberately best-effort: a missing helper module or any failure inside
    # it must not turn a finished training run into an error.
    try:
        from update_readme_metrics import (extract_metrics,
                                           find_best_checkpoint,
                                           update_readme_metrics)

        checkpoint_dir = find_best_checkpoint(Path(args.output_dir))
        if checkpoint_dir:
            metrics = extract_metrics(checkpoint_dir)
            readme_path = _PROJECT_ROOT / "README.md"
            if update_readme_metrics(readme_path, metrics):
                print(
                    f"\nREADME metrics updated automatically from {checkpoint_dir.name}"
                )
    except Exception as e:
        print(f"\nNote: Could not auto-update README metrics: {e}")

    # Print next steps.
    # Only point at a follow-up stage when one is actually configured; the
    # final stage previously advertised a non-existent "stage 7".
    next_stage = args.stage + 1
    print(f"\nNext steps:")
    print(
        f" 1. Upload model to HuggingFace (e.g., 'your-username/aramt5-v{args.stage}')"
    )
    if next_stage in STAGE_CONFIGS:
        print(f" 2. Run stage {next_stage}:")
        print(
            f" python src/train_t5.py --stage {next_stage} --hf-model your-username/aramt5-v{args.stage}"
        )
# Script entry point: parse the command line and run the selected stage.
if __name__ == "__main__":
    train(parse_args())