""" AramT5 Curriculum Learning Trainer Features: - Curriculum learning: short → long sequences - Catastrophic forgetting mitigation: mixes short examples into later stages - Character Error Rate (CER) evaluation for transliteration quality - Early stopping based on validation loss improvement threshold """ import argparse import subprocess import sys from pathlib import Path import numpy as np import torch from datasets import concatenate_datasets, load_dataset from transformers import (DataCollatorForSeq2Seq, EarlyStoppingCallback, Seq2SeqTrainer, Seq2SeqTrainingArguments, T5Config, T5ForConditionalGeneration, T5TokenizerFast) # ============================================================================= # Configuration # ============================================================================= # Resolve paths relative to project root (parent of src/) _PROJECT_ROOT = Path(__file__).resolve().parent.parent # Default paths (relative to project root) # Use balanced corpus: 40% single, 30% two-word, 30% multi-word # (augmented corpus was 98.5% single, causing truncated multi-word outputs) DEFAULT_WEST_DATA = str(_PROJECT_ROOT / "src/data/syriac_west_balanced_corpus.jsonl") DEFAULT_EAST_DATA = str(_PROJECT_ROOT / "src/data/syriac_east_balanced_corpus.jsonl") # Source files for balancing (input to balance_corpus.py) AUGMENTED_WEST_DATA = _PROJECT_ROOT / "src/data/syriac_west_augmented_corpus.jsonl" AUGMENTED_EAST_DATA = _PROJECT_ROOT / "src/data/syriac_east_augmented_corpus.jsonl" # Source files for augmentation (input to augment_atomic_tokens.py) CLEAN_WEST_DATA = _PROJECT_ROOT / "src/data/syriac_west_clean_corpus.jsonl" CLEAN_EAST_DATA = _PROJECT_ROOT / "src/data/syriac_east_clean_corpus.jsonl" DEFAULT_TOKENISER = str(_PROJECT_ROOT / "src/tokeniser") DEFAULT_OUTPUT_DIR = str(_PROJECT_ROOT / "checkpoints") def ensure_augmented_corpus(): """ Ensure augmented corpus files exist. If augmented corpus is missing or older than clean corpus, regenerate it by running augment_atomic_tokens.py. """ needs_augment = False # Check if augmented files exist if not AUGMENTED_WEST_DATA.exists() or not AUGMENTED_EAST_DATA.exists(): print("Augmented corpus not found, will generate...") needs_augment = True else: # Check if clean files are newer (source changed) if CLEAN_WEST_DATA.exists(): if CLEAN_WEST_DATA.stat().st_mtime > AUGMENTED_WEST_DATA.stat().st_mtime: print("Clean corpus is newer than augmented corpus, regenerating...") needs_augment = True if CLEAN_EAST_DATA.exists(): if CLEAN_EAST_DATA.stat().st_mtime > AUGMENTED_EAST_DATA.stat().st_mtime: print("Clean corpus is newer than augmented corpus, regenerating...") needs_augment = True if needs_augment: augment_script = _PROJECT_ROOT / "src/data/augment_atomic_tokens.py" if not augment_script.exists(): raise FileNotFoundError( f"Cannot regenerate augmented corpus: {augment_script} not found" ) print("Running augment_atomic_tokens.py to generate augmented training data...") result = subprocess.run( [sys.executable, str(augment_script)], cwd=str(_PROJECT_ROOT), capture_output=True, text=True, ) if result.returncode != 0: print(f"Error running augment_atomic_tokens.py:\n{result.stderr}") raise RuntimeError("Failed to generate augmented corpus") print(result.stdout) print("Augmented corpus generated successfully.") else: print("Augmented corpus is up-to-date.") def ensure_balanced_corpus(): """ Ensure balanced corpus files exist. Pipeline: clean_corpus -> augmented_corpus -> balanced_corpus If balanced corpus is missing or older than augmented corpus, regenerate it by running balance_corpus.py. """ # First ensure augmented corpus exists (upstream dependency) ensure_augmented_corpus() west_balanced = Path(DEFAULT_WEST_DATA) east_balanced = Path(DEFAULT_EAST_DATA) needs_rebalance = False # Check if balanced files exist if not west_balanced.exists() or not east_balanced.exists(): print("Balanced corpus not found, will generate...") needs_rebalance = True else: # Check if augmented files are newer (source changed) if AUGMENTED_WEST_DATA.exists(): if AUGMENTED_WEST_DATA.stat().st_mtime > west_balanced.stat().st_mtime: print("Augmented corpus is newer than balanced corpus, regenerating...") needs_rebalance = True if AUGMENTED_EAST_DATA.exists(): if AUGMENTED_EAST_DATA.stat().st_mtime > east_balanced.stat().st_mtime: print("Augmented corpus is newer than balanced corpus, regenerating...") needs_rebalance = True if needs_rebalance: balance_script = _PROJECT_ROOT / "src/data/balance_corpus.py" if not balance_script.exists(): raise FileNotFoundError( f"Cannot regenerate balanced corpus: {balance_script} not found" ) print("Running balance_corpus.py to generate balanced training data...") result = subprocess.run( [sys.executable, str(balance_script)], cwd=str(_PROJECT_ROOT), capture_output=True, text=True, ) if result.returncode != 0: print(f"Error running balance_corpus.py:\n{result.stderr}") raise RuntimeError("Failed to generate balanced corpus") print(result.stdout) print("Balanced corpus generated successfully.") else: print("Balanced corpus is up-to-date.") # Curriculum learning stage configurations STAGE_CONFIGS = { 1: { "description": "Baseline: short sequences only", "num_samples": 20_000, "max_src_length": 15, # Characters in source (short words) "short_mix_ratio": 0.0, # No mixing needed in stage 1 "num_epochs": 30, "learning_rate": 3e-4, }, 2: { "description": "Expansion: short phrases", "num_samples": 40_000, "max_src_length": 30, "short_mix_ratio": 0.12, # 12% short examples from previous stages "short_threshold": 15, # ≤15 chars (Stage 1) "new_range_ratio": 0.50, # 50% from new range (16-30 chars) "new_range_min": 16, "num_epochs": 20, "learning_rate": 1.2e-4, }, 3: { "description": "Expansion: medium phrases", "num_samples": 60_000, "max_src_length": 50, "short_mix_ratio": 0.12, # 12% short examples from previous stages "short_threshold": 30, # ≤30 chars (Stage 1+2) "new_range_ratio": 0.50, # 50% from new range (31-50 chars) "new_range_min": 31, "num_epochs": 20, "learning_rate": 1e-4, }, 4: { "description": "Extension: longer phrases", "num_samples": 120_000, # Increased to better learn multi-word patterns "max_src_length": 70, "short_mix_ratio": 0.18, # 18% short examples from previous stages (boosted for retention) "short_threshold": 50, # ≤50 chars (Stage 1+2+3) "new_range_ratio": 0.45, # 45% from new range (51-70 chars) "new_range_min": 51, "num_epochs": 10, "learning_rate": 8e-5, # Higher LR to unlearn early-stopping bias from imbalanced data }, 5: { "description": "Extension: longer sentences", "num_samples": 150_000, # Increased to better learn multi-word patterns "max_src_length": 100, "short_mix_ratio": 0.18, # 18% short examples from previous stages (boosted for retention) "short_threshold": 70, # ≤70 chars (Stage 1+2+3+4) "new_range_ratio": 0.45, # 45% from new range (71-100 chars) "new_range_min": 71, "num_epochs": 10, "learning_rate": 5e-5, # Slightly higher to reinforce multi-word patterns "repetition_penalty": 1.2, }, 6: { "description": "Full practical corpus: sentences and short paragraphs", "num_samples": 180_000, # Increased to better learn multi-word patterns "max_src_length": 150, "short_mix_ratio": 0.20, # 20% short examples from previous stages (highest retention) "short_threshold": 100, # ≤100 chars (Stage 1+2+3+4+5) "new_range_ratio": 0.40, # 40% from new range (101-150 chars) "new_range_min": 101, "num_epochs": 10, "learning_rate": 4e-5, # Fine-tuning polish "repetition_penalty": 1.2, }, } # Early stopping config EARLY_STOPPING_PATIENCE = 3 EARLY_STOPPING_THRESHOLD = 0.005 # 0.5% improvement threshold def parse_args(): parser = argparse.ArgumentParser(description="AramT5 Curriculum Learning Trainer") parser.add_argument( "--stage", type=int, default=1, choices=[1, 2, 3, 4, 5, 6], help="Training stage (1=baseline, 2=medium-long, 3=expansion, 4=extension, 5=longer sentences, 6=full practical)", ) parser.add_argument( "--hf-model", type=str, default=None, help="HuggingFace model ID to fine-tune (required for stage 2+)", ) parser.add_argument( "--west-data", type=str, default=DEFAULT_WEST_DATA, help="Path to West Syriac corpus", ) parser.add_argument( "--east-data", type=str, default=DEFAULT_EAST_DATA, help="Path to East Syriac corpus", ) parser.add_argument( "--tokeniser", type=str, default=DEFAULT_TOKENISER, help="Path to tokeniser", ) parser.add_argument( "--output-dir", type=str, default=DEFAULT_OUTPUT_DIR, help="Output directory for checkpoints", ) parser.add_argument( "--batch-size", type=int, default=16, help="Per-device batch size", ) parser.add_argument( "--no-early-stopping", action="store_true", help="Disable early stopping", ) parser.add_argument( "--resume", type=str, nargs="?", const="auto", default=None, help="Resume from checkpoint. Use --resume for auto-detect or --resume path/to/checkpoint", ) return parser.parse_args() # ============================================================================= # Model Loading # ============================================================================= def load_model_and_tokeniser( stage: int = 1, hf_model: str | None = None, tokeniser_path: str = DEFAULT_TOKENISER, ): """ Load model and tokeniser based on training stage. Args: stage: Training stage (1=baseline, 2+=fine-tune from HF) hf_model: HuggingFace model ID (required for stage 2+) tokeniser_path: Path to local tokeniser directory Returns: Tuple of (model, tokeniser) """ tokeniser = T5TokenizerFast.from_pretrained(tokeniser_path) vocab_size = tokeniser.vocab_size pad_token_id = tokeniser.pad_token_id if stage == 1: # Stage 1: Initialise from scratch with custom config print("Stage 1: Initialising new model from scratch...") config = T5Config( vocab_size=vocab_size, d_model=512, d_ff=2048, num_layers=6, num_heads=8, pad_token_id=pad_token_id, decoder_start_token_id=pad_token_id, tie_word_embeddings=True, ) model = T5ForConditionalGeneration(config) else: # Stage 2+: Load from HuggingFace if not hf_model: raise ValueError(f"Stage {stage} requires --hf-model argument") print(f"Stage {stage}: Loading model from HuggingFace: {hf_model}") model = T5ForConditionalGeneration.from_pretrained(hf_model) return model, tokeniser # ============================================================================= # Data Processing # ============================================================================= def get_src_length(example): """Extract source text length for curriculum sorting.""" return len(example["transliteration"]["src"]) def create_tokenise_function(tokeniser): """Create tokenisation function with closure over tokeniser.""" pad_token_id = tokeniser.pad_token_id def tokenise_function(example: dict) -> dict: """ Tokenise input data with dialect-aware task prefix. Task prefixes: - "Syriac2WestLatin: " for West Syriac (Serto) - "Syriac2EastLatin: " for East Syriac (Madnḥaya) """ inputs = [] targets = [] for item in example["transliteration"]: dialect = item.get("dialect", "west") if dialect == "east": prefix = "Syriac2EastLatin: " else: prefix = "Syriac2WestLatin: " inputs.append(f"{prefix}{item['src']}") targets.append(item["tgt"]) model_inputs = tokeniser( inputs, max_length=256, truncation=True, padding="max_length" ) labels = tokeniser( targets, max_length=256, truncation=True, padding="max_length" )["input_ids"] # Replace padding token id with -100 so it's ignored in loss computation labels = [ [(token if token != pad_token_id else -100) for token in label] for label in labels ] model_inputs["labels"] = labels return model_inputs return tokenise_function def load_and_prepare_data( stage_config: dict, stage: int = 1, west_data: str = DEFAULT_WEST_DATA, east_data: str = DEFAULT_EAST_DATA, ): """ Load and prepare data according to curriculum learning stage. Args: stage_config: Configuration dict for the current stage stage: Training stage number (for logging and mixing logic) west_data: Path to West Syriac corpus JSONL file east_data: Path to East Syriac corpus JSONL file Returns: Tuple of (train_dataset, val_dataset) filtered by sequence length. """ print(f"\n{'=' * 60}") print(f"Stage {stage}: {stage_config['description']}") print(f"{'=' * 60}\n") # Load both dialect corpora print("Loading West Syriac corpus...") west_dataset = load_dataset("json", data_files=west_data, split="train") print(f" Loaded {len(west_dataset)} examples") print("Loading East Syriac corpus...") east_dataset = load_dataset("json", data_files=east_data, split="train") print(f" Loaded {len(east_dataset)} examples") # Combine datasets full_dataset = concatenate_datasets([west_dataset, east_dataset]) print(f"Total combined: {len(full_dataset)} examples") # Add source length column for filtering/sorting full_dataset = full_dataset.map( lambda x: {"src_length": get_src_length(x)}, num_proc=4 ) # Sort by length (curriculum: short → long) full_dataset = full_dataset.sort("src_length") # Apply length filter if specified max_len = stage_config["max_src_length"] if max_len is not None: print(f"\nFiltering to sequences with src_length <= {max_len} characters...") filtered_dataset = full_dataset.filter(lambda x: x["src_length"] <= max_len) print(f" After filtering: {len(filtered_dataset)} examples") else: filtered_dataset = full_dataset print("\nNo length filter applied (using all sequences)") # Sample to target size num_samples = min(stage_config["num_samples"], len(filtered_dataset)) print(f"\nSampling {num_samples} examples for training...") # For stages 2+, mix in some short examples to prevent catastrophic forgetting short_mix_ratio = stage_config["short_mix_ratio"] middle_oversample = stage_config.get("middle_oversample", False) if middle_oversample: # Stage 4 special handling: oversample the rare 15-100 char range # to build bridge competence before full corpus num_short = int(num_samples * short_mix_ratio) num_middle = int(num_samples * 0.40) # 40% from middle range (15-100) num_main = num_samples - num_short - num_middle # Short examples (≤15 chars = Stage 1 range) for forgetting mitigation short_threshold = 15 short_examples = full_dataset.filter( lambda x: x["src_length"] <= short_threshold ) short_examples = short_examples.shuffle(seed=42).select( range(min(num_short, len(short_examples))) ) print(f" Short examples (≤{short_threshold} chars): {len(short_examples)}") # Middle-range examples (15-100 chars) - oversample these rare sequences middle_examples = filtered_dataset.filter(lambda x: 15 < x["src_length"] <= 100) # Repeat/oversample if needed since these are scarce if len(middle_examples) < num_middle: # Repeat the middle examples to reach target repeats_needed = (num_middle // len(middle_examples)) + 1 middle_repeated = concatenate_datasets([middle_examples] * repeats_needed) middle_examples = middle_repeated.shuffle(seed=42).select(range(num_middle)) print( f" Middle-range examples (15-100 chars, oversampled): {len(middle_examples)}" ) else: middle_examples = middle_examples.shuffle(seed=42).select(range(num_middle)) print(f" Middle-range examples (15-100 chars): {len(middle_examples)}") # Main examples from full filtered set main_examples = filtered_dataset.shuffle(seed=43).select( range(min(num_main, len(filtered_dataset))) ) print(f" Main examples: {len(main_examples)}") # Combine and shuffle sampled_dataset = concatenate_datasets( [short_examples, middle_examples, main_examples] ) sampled_dataset = sampled_dataset.shuffle(seed=42) elif short_mix_ratio > 0 and stage > 1: # Stratified sampling: ensure we get examples from the NEW length range new_range_ratio = stage_config.get("new_range_ratio", 0) new_range_min = stage_config.get("new_range_min", 0) num_short = int(num_samples * short_mix_ratio) if new_range_ratio > 0 and new_range_min > 0: # Stratified: short + new_range + remainder num_new_range = int(num_samples * new_range_ratio) num_remainder = num_samples - num_short - num_new_range # Short examples = everything from previous stages (for forgetting mitigation) short_threshold = stage_config.get("short_threshold", 15) short_examples = full_dataset.filter( lambda x, thresh=short_threshold: x["src_length"] <= thresh ) short_examples = short_examples.shuffle(seed=42).select( range(min(num_short, len(short_examples))) ) print( f" Short examples (≤{short_threshold} chars, previous stages): {len(short_examples)}" ) # New range examples - these are what the model needs to learn new_range_examples = filtered_dataset.filter( lambda x, min_len=new_range_min: x["src_length"] >= min_len ) print( f" New range pool ({new_range_min}-{max_len} chars): {len(new_range_examples)} available" ) # Oversample if needed (these are scarce!) if len(new_range_examples) < num_new_range: if len(new_range_examples) > 0: repeats_needed = (num_new_range // len(new_range_examples)) + 1 new_range_repeated = concatenate_datasets( [new_range_examples] * repeats_needed ) new_range_examples = new_range_repeated.shuffle(seed=42).select( range(num_new_range) ) print( f" New range examples (oversampled {repeats_needed}x): {len(new_range_examples)}" ) else: print(f" WARNING: No examples in new range!") new_range_examples = full_dataset.filter(lambda x: False) # empty else: new_range_examples = new_range_examples.shuffle(seed=42).select( range(num_new_range) ) print(f" New range examples: {len(new_range_examples)}") # Remainder from full filtered set (includes all lengths up to max) remainder_examples = filtered_dataset.shuffle(seed=43).select( range(min(num_remainder, len(filtered_dataset))) ) print(f" Remainder examples: {len(remainder_examples)}") # Combine and shuffle sampled_dataset = concatenate_datasets( [short_examples, new_range_examples, remainder_examples] ) sampled_dataset = sampled_dataset.shuffle(seed=42) else: # Original logic: just short + main num_main = num_samples - num_short # Get short examples = everything from previous stages short_threshold = stage_config.get("short_threshold", 15) short_examples = full_dataset.filter( lambda x, thresh=short_threshold: x["src_length"] <= thresh ) short_examples = short_examples.shuffle(seed=42).select( range(min(num_short, len(short_examples))) ) print( f" Short examples (≤{short_threshold} chars, previous stages): {len(short_examples)}" ) # Get main examples from filtered dataset # Apply minimum length filter for main examples in later stages min_len = stage_config.get("min_src_length", 0) if min_len > 0: main_pool = filtered_dataset.filter( lambda x: x["src_length"] >= min_len ) print( f" Main pool after min_length={min_len} filter: {len(main_pool)} examples" ) else: main_pool = filtered_dataset main_examples = main_pool.shuffle(seed=42).select( range(min(num_main, len(main_pool))) ) print(f" Main examples: {len(main_examples)}") # Combine and shuffle sampled_dataset = concatenate_datasets([short_examples, main_examples]) sampled_dataset = sampled_dataset.shuffle(seed=42) else: sampled_dataset = filtered_dataset.shuffle(seed=42).select(range(num_samples)) print(f" Final training pool: {len(sampled_dataset)} examples") # Split into train/validation (90/10 for stages 4-5, 80/20 for earlier stages) val_ratio = 0.1 if stage >= 4 else 0.2 dataset_split = sampled_dataset.train_test_split(test_size=val_ratio, seed=42) train_dataset = dataset_split["train"] val_dataset = dataset_split["test"] print(f"\nTrain set: {len(train_dataset)} examples") print(f"Validation set: {len(val_dataset)} examples") # Report length statistics train_lengths = train_dataset["src_length"] print(f"\nSource length statistics (train):") print(f" Min: {min(train_lengths)}, Max: {max(train_lengths)}") print( f" Mean: {np.mean(train_lengths):.1f}, Median: {np.median(train_lengths):.1f}" ) return train_dataset, val_dataset # ============================================================================= # Evaluation Metrics # ============================================================================= def compute_cer(pred_str: str, target_str: str) -> float: """ Compute Character Error Rate between prediction and target. CER = (substitutions + insertions + deletions) / len(target) Uses edit distance (Levenshtein distance). """ if len(target_str) == 0: return 0.0 if len(pred_str) == 0 else 1.0 # Simple Levenshtein distance implementation m, n = len(pred_str), len(target_str) dp = [[0] * (n + 1) for _ in range(m + 1)] for i in range(m + 1): dp[i][0] = i for j in range(n + 1): dp[0][j] = j for i in range(1, m + 1): for j in range(1, n + 1): if pred_str[i - 1] == target_str[j - 1]: dp[i][j] = dp[i - 1][j - 1] else: dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) return dp[m][n] / len(target_str) def create_compute_metrics(tokeniser): """Create metrics computation function with closure over tokeniser.""" sample_count = [0] # Mutable counter for periodic logging def compute_metrics(eval_preds): """Compute CER and exact match accuracy for evaluation.""" preds, labels = eval_preds # Replace -100 with pad token for decoding labels = np.where(labels != -100, labels, tokeniser.pad_token_id) preds = np.where(preds != -100, preds, tokeniser.pad_token_id) # Decode to strings pred_strs = tokeniser.batch_decode(preds, skip_special_tokens=True) label_strs = tokeniser.batch_decode(labels, skip_special_tokens=True) # Log sample predictions periodically for debugging sample_count[0] += 1 if sample_count[0] % 2 == 1: # Every other eval print("\n--- Sample predictions (first 5) ---") for i in range(min(5, len(pred_strs))): print(f" Target: '{label_strs[i]}'") print(f" Pred: '{pred_strs[i]}'") print(f" CER: {compute_cer(pred_strs[i], label_strs[i]):.3f}") print() # Compute metrics cer_scores = [ compute_cer(pred, target) for pred, target in zip(pred_strs, label_strs) ] exact_matches = [ 1.0 if pred.strip() == target.strip() else 0.0 for pred, target in zip(pred_strs, label_strs) ] # Log length statistics pred_lens = [len(p) for p in pred_strs] label_lens = [len(l) for l in label_strs] print( f" Avg pred len: {np.mean(pred_lens):.1f}, Avg label len: {np.mean(label_lens):.1f}" ) # Compute length ratio penalty (penalise under-generation) # Ratio < 1 means output is shorter than target length_ratios = [ len(pred) / max(len(target), 1) for pred, target in zip(pred_strs, label_strs) ] # Penalty: how much shorter outputs are on average (0 = perfect, higher = worse) # Only penalise under-generation (ratio < 1), not over-generation length_penalties = [max(0, 1 - ratio) for ratio in length_ratios] avg_length_penalty = np.mean(length_penalties) avg_length_ratio = np.mean(length_ratios) print( f" Avg length ratio: {avg_length_ratio:.3f}, Avg length penalty: {avg_length_penalty:.3f}" ) return { "cer": np.mean(cer_scores), "exact_match": np.mean(exact_matches), "length_ratio": avg_length_ratio, "length_penalty": avg_length_penalty, } return compute_metrics # ============================================================================= # Training # ============================================================================= def train(args): """Main training function implementing curriculum learning.""" # Ensure balanced corpus exists (auto-regenerate if needed) ensure_balanced_corpus() stage_config = STAGE_CONFIGS[args.stage] # Load model and tokeniser model, tokeniser = load_model_and_tokeniser( stage=args.stage, hf_model=args.hf_model, tokeniser_path=args.tokeniser, ) # Enable gradient checkpointing to save memory model.gradient_checkpointing_enable() # Load and prepare data train_dataset, val_dataset = load_and_prepare_data( stage_config=stage_config, stage=args.stage, west_data=args.west_data, east_data=args.east_data, ) # Tokenise datasets tokenise_fn = create_tokenise_function(tokeniser) tokenised_train = train_dataset.map( tokenise_fn, batched=True, remove_columns=train_dataset.column_names, desc="Tokenising train set", ) tokenised_eval = val_dataset.map( tokenise_fn, batched=True, remove_columns=val_dataset.column_names, desc="Tokenising eval set", ) # Data collator data_collator = DataCollatorForSeq2Seq(tokenizer=tokeniser, model=model) # Training arguments # Stage-specific hyperparameters for better early learning grad_accum = 4 if args.stage <= 2 else 8 # Smaller effective batch for early stages label_smooth = 0.05 if args.stage == 1 else (0.08 if args.stage <= 3 else 0.1) warmup = 0.10 if args.stage == 1 else 0.06 # More warmup for training from scratch training_args = Seq2SeqTrainingArguments( output_dir=args.output_dir, overwrite_output_dir=True, num_train_epochs=stage_config["num_epochs"], per_device_train_batch_size=args.batch_size, per_device_eval_batch_size=args.batch_size, gradient_accumulation_steps=grad_accum, learning_rate=stage_config["learning_rate"], warmup_ratio=warmup, weight_decay=0.01, label_smoothing_factor=label_smooth, save_strategy="epoch", save_total_limit=3, eval_strategy="epoch", logging_dir="logs", fp16=torch.cuda.is_available(), load_best_model_at_end=True, metric_for_best_model="eval_loss", greater_is_better=False, report_to="none", predict_with_generate=True, generation_max_length=256, # Generous headroom for all stages ) # Configure generation settings # max_length is total sequence length - set high to avoid truncation model.generation_config.max_length = 256 # Don't use no_repeat_ngram_size - it blocks valid Syriac patterns # Don't use repetition_penalty - transliteration has legitimate repetition model.generation_config.eos_token_id = tokeniser.eos_token_id model.generation_config.pad_token_id = tokeniser.pad_token_id # Minimum length and length_penalty to discourage under-generation # Applied to ALL stages with progressive values model.generation_config.min_length = 2 if args.stage < 5 else 3 # Use beam search with length_penalty to encourage full-length outputs # Progressive beam size and length penalty by stage # Increased penalties to counter systematic under-generation (~8% shorter outputs) if args.stage == 1: model.generation_config.num_beams = 2 model.generation_config.length_penalty = 1.05 # Slight encouragement from start elif args.stage == 2: model.generation_config.num_beams = 2 model.generation_config.length_penalty = 1.12 # Counter under-generation elif args.stage == 3: model.generation_config.num_beams = 3 model.generation_config.length_penalty = 1.18 elif args.stage == 4: model.generation_config.num_beams = 4 model.generation_config.length_penalty = 1.22 else: # stages 5-6 model.generation_config.num_beams = 4 model.generation_config.length_penalty = ( 1.25 # >1.0 encourages longer sequences ) model.generation_config.early_stopping = True # Callbacks callbacks = [] if not args.no_early_stopping: callbacks.append( EarlyStoppingCallback( early_stopping_patience=EARLY_STOPPING_PATIENCE, early_stopping_threshold=EARLY_STOPPING_THRESHOLD, ) ) print(f"\nEarly stopping enabled:") print(f" Patience: {EARLY_STOPPING_PATIENCE} evaluations") print(f" Threshold: {EARLY_STOPPING_THRESHOLD * 100:.1f}% improvement") # Trainer trainer = Seq2SeqTrainer( model=model, args=training_args, train_dataset=tokenised_train, eval_dataset=tokenised_eval, data_collator=data_collator, processing_class=tokeniser, compute_metrics=create_compute_metrics(tokeniser), callbacks=callbacks, ) # Train print(f"\n{'=' * 60}") print("Starting training...") print(f"{'=' * 60}\n") # Handle checkpoint resumption resume_from_checkpoint = None if args.resume: if args.resume == "auto": # Auto-detect: let Trainer find the last checkpoint resume_from_checkpoint = True print("Resuming from last checkpoint (auto-detect)...") else: # Specific checkpoint path provided resume_from_checkpoint = args.resume print(f"Resuming from checkpoint: {args.resume}") trainer.train(resume_from_checkpoint=resume_from_checkpoint) # Save final model final_output_dir = f"{args.output_dir}/stage{args.stage}-final" model.save_pretrained(final_output_dir) tokeniser.save_pretrained(final_output_dir) print(f"\n{'=' * 60}") print(f"Stage {args.stage} training complete!") print(f"Model saved to: {final_output_dir}") print(f"{'=' * 60}") # Update README metrics try: from update_readme_metrics import (extract_metrics, find_best_checkpoint, update_readme_metrics) checkpoint_dir = find_best_checkpoint(Path(args.output_dir)) if checkpoint_dir: metrics = extract_metrics(checkpoint_dir) readme_path = _PROJECT_ROOT / "README.md" if update_readme_metrics(readme_path, metrics): print( f"\nREADME metrics updated automatically from {checkpoint_dir.name}" ) except Exception as e: print(f"\nNote: Could not auto-update README metrics: {e}") # Print next steps if args.stage <= 6: print(f"\nNext steps:") print( f" 1. Upload model to HuggingFace (e.g., 'your-username/aramt5-v{args.stage}')" ) print(f" 2. Run stage {args.stage + 1}:") print( f" python src/train_t5.py --stage {args.stage + 1} --hf-model your-username/aramt5-v{args.stage}" ) if __name__ == "__main__": args = parse_args() train(args)