| """ |
| AramT5 Curriculum Learning Trainer |
| |
| Features: |
| - Curriculum learning: short → long sequences |
| - Catastrophic forgetting mitigation: mixes short examples into later stages |
| - Character Error Rate (CER) evaluation for transliteration quality |
| - Early stopping based on validation loss improvement threshold |
| """ |
|
|
| import argparse |
| import subprocess |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from datasets import concatenate_datasets, load_dataset |
| from transformers import (DataCollatorForSeq2Seq, EarlyStoppingCallback, |
| Seq2SeqTrainer, Seq2SeqTrainingArguments, T5Config, |
| T5ForConditionalGeneration, T5TokenizerFast) |
|
|
| |
| |
| |
|
|
| |
# Repository root: two directory levels above this file.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Directory that holds every corpus file referenced below.
_DATA_DIR = _PROJECT_ROOT / "src" / "data"

# Balanced corpora: the default training inputs (kept as strings for argparse).
DEFAULT_WEST_DATA = str(_DATA_DIR / "syriac_west_balanced_corpus.jsonl")
DEFAULT_EAST_DATA = str(_DATA_DIR / "syriac_east_balanced_corpus.jsonl")

# Augmented corpora: intermediate files the balanced corpora are compared against.
AUGMENTED_WEST_DATA = _DATA_DIR / "syriac_west_augmented_corpus.jsonl"
AUGMENTED_EAST_DATA = _DATA_DIR / "syriac_east_augmented_corpus.jsonl"

# Clean corpora: the files the augmented corpora are compared against.
CLEAN_WEST_DATA = _DATA_DIR / "syriac_west_clean_corpus.jsonl"
CLEAN_EAST_DATA = _DATA_DIR / "syriac_east_clean_corpus.jsonl"

# Default tokeniser directory and checkpoint output directory.
DEFAULT_TOKENISER = str(_PROJECT_ROOT / "src" / "tokeniser")
DEFAULT_OUTPUT_DIR = str(_PROJECT_ROOT / "checkpoints")
|
|
|
|
def ensure_augmented_corpus():
    """
    Make sure both augmented corpus files are present and fresh.

    The augmented corpora are regenerated with ``augment_atomic_tokens.py``
    when either file is missing, or when an existing clean corpus file has a
    later modification time than its augmented counterpart.

    Raises:
        FileNotFoundError: Regeneration is needed but the script is missing.
        RuntimeError: The augmentation script exited with a non-zero status.
    """
    clean_to_augmented = (
        (CLEAN_WEST_DATA, AUGMENTED_WEST_DATA),
        (CLEAN_EAST_DATA, AUGMENTED_EAST_DATA),
    )

    if not (AUGMENTED_WEST_DATA.exists() and AUGMENTED_EAST_DATA.exists()):
        print("Augmented corpus not found, will generate...")
        stale = True
    else:
        stale = False
        # A missing clean corpus is skipped: there is nothing to compare against.
        for clean_path, augmented_path in clean_to_augmented:
            if not clean_path.exists():
                continue
            if clean_path.stat().st_mtime > augmented_path.stat().st_mtime:
                print("Clean corpus is newer than augmented corpus, regenerating...")
                stale = True

    if not stale:
        print("Augmented corpus is up-to-date.")
        return

    script_path = _PROJECT_ROOT / "src/data/augment_atomic_tokens.py"
    if not script_path.exists():
        raise FileNotFoundError(
            f"Cannot regenerate augmented corpus: {script_path} not found"
        )

    print("Running augment_atomic_tokens.py to generate augmented training data...")
    # Run with the current interpreter from the project root; output is
    # captured and only shown once the script has finished.
    completed = subprocess.run(
        [sys.executable, str(script_path)],
        cwd=str(_PROJECT_ROOT),
        capture_output=True,
        text=True,
    )

    if completed.returncode != 0:
        print(f"Error running augment_atomic_tokens.py:\n{completed.stderr}")
        raise RuntimeError("Failed to generate augmented corpus")

    print(completed.stdout)
    print("Augmented corpus generated successfully.")
|
|
|
|
def ensure_balanced_corpus():
    """
    Make sure both balanced corpus files are present and fresh.

    Pipeline: clean_corpus -> augmented_corpus -> balanced_corpus

    The augmented corpora are checked first. The balanced corpora are then
    regenerated with ``balance_corpus.py`` when either file is missing, or when
    an existing augmented corpus file has a later modification time than its
    balanced counterpart.

    Raises:
        FileNotFoundError: Regeneration is needed but the script is missing.
        RuntimeError: The balancing script exited with a non-zero status.
    """
    # The upstream stage must be current before freshness is compared.
    ensure_augmented_corpus()

    balanced_west = Path(DEFAULT_WEST_DATA)
    balanced_east = Path(DEFAULT_EAST_DATA)
    augmented_to_balanced = (
        (AUGMENTED_WEST_DATA, balanced_west),
        (AUGMENTED_EAST_DATA, balanced_east),
    )

    if not (balanced_west.exists() and balanced_east.exists()):
        print("Balanced corpus not found, will generate...")
        stale = True
    else:
        stale = False
        for augmented_path, balanced_path in augmented_to_balanced:
            if not augmented_path.exists():
                continue
            if augmented_path.stat().st_mtime > balanced_path.stat().st_mtime:
                print("Augmented corpus is newer than balanced corpus, regenerating...")
                stale = True

    if not stale:
        print("Balanced corpus is up-to-date.")
        return

    script_path = _PROJECT_ROOT / "src/data/balance_corpus.py"
    if not script_path.exists():
        raise FileNotFoundError(
            f"Cannot regenerate balanced corpus: {script_path} not found"
        )

    print("Running balance_corpus.py to generate balanced training data...")
    # Run with the current interpreter from the project root; output is
    # captured and only shown once the script has finished.
    completed = subprocess.run(
        [sys.executable, str(script_path)],
        cwd=str(_PROJECT_ROOT),
        capture_output=True,
        text=True,
    )

    if completed.returncode != 0:
        print(f"Error running balance_corpus.py:\n{completed.stderr}")
        raise RuntimeError("Failed to generate balanced corpus")

    print(completed.stdout)
    print("Balanced corpus generated successfully.")
|
|
|
|
| |
# Per-stage curriculum settings, keyed by stage number (1-6).
#
# Keys read by the training code in this file:
#   description      - human-readable label printed when data loading starts
#   num_samples      - upper bound on the number of examples sampled for the stage
#   max_src_length   - keep only examples whose source length (characters) is <= this
#   short_mix_ratio  - fraction of num_samples drawn from short examples
#                      (only applied for stages > 1)
#   short_threshold  - source length at or below which an example counts as "short"
#   new_range_ratio  - fraction of num_samples drawn from the newly unlocked length range
#   new_range_min    - minimum source length of the newly unlocked range
#   num_epochs       - number of training epochs
#   learning_rate    - optimiser learning rate
#   repetition_penalty - presumably a generation repetition penalty for the
#                      stage; NOTE(review): confirm where this value is consumed
STAGE_CONFIGS = {
    1: {
        "description": "Baseline: short sequences only",
        "num_samples": 20_000,
        "max_src_length": 15,
        "short_mix_ratio": 0.0,
        "num_epochs": 30,
        "learning_rate": 3e-4,
    },
    2: {
        "description": "Expansion: short phrases",
        "num_samples": 40_000,
        "max_src_length": 30,
        "short_mix_ratio": 0.12,
        "short_threshold": 15,
        "new_range_ratio": 0.50,
        "new_range_min": 16,
        "num_epochs": 20,
        "learning_rate": 1.2e-4,
    },
    3: {
        "description": "Expansion: medium phrases",
        "num_samples": 60_000,
        "max_src_length": 50,
        "short_mix_ratio": 0.12,
        "short_threshold": 30,
        "new_range_ratio": 0.50,
        "new_range_min": 31,
        "num_epochs": 20,
        "learning_rate": 1e-4,
    },
    4: {
        "description": "Extension: longer phrases",
        "num_samples": 120_000,
        "max_src_length": 70,
        "short_mix_ratio": 0.18,
        "short_threshold": 50,
        "new_range_ratio": 0.45,
        "new_range_min": 51,
        "num_epochs": 10,
        "learning_rate": 8e-5,
    },
    5: {
        "description": "Extension: longer sentences",
        "num_samples": 150_000,
        "max_src_length": 100,
        "short_mix_ratio": 0.18,
        "short_threshold": 70,
        "new_range_ratio": 0.45,
        "new_range_min": 71,
        "num_epochs": 10,
        "learning_rate": 5e-5,
        "repetition_penalty": 1.2,
    },
    6: {
        "description": "Full practical corpus: sentences and short paragraphs",
        "num_samples": 180_000,
        "max_src_length": 150,
        "short_mix_ratio": 0.20,
        "short_threshold": 100,
        "new_range_ratio": 0.40,
        "new_range_min": 101,
        "num_epochs": 10,
        "learning_rate": 4e-5,
        "repetition_penalty": 1.2,
    },
}


# Early-stopping settings handed to EarlyStoppingCallback:
# the number of evaluations to wait, and the improvement threshold.
EARLY_STOPPING_PATIENCE = 3
EARLY_STOPPING_THRESHOLD = 0.005
|
|
|
|
def parse_args():
    """
    Define the trainer's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``stage``, ``hf_model``,
        ``west_data``, ``east_data``, ``tokeniser``, ``output_dir``,
        ``batch_size``, ``no_early_stopping`` and ``resume``.
    """
    cli = argparse.ArgumentParser(description="AramT5 Curriculum Learning Trainer")

    # (flag, add_argument keyword options), registered in display order.
    option_specs = [
        (
            "--stage",
            {
                "type": int,
                "default": 1,
                "choices": [1, 2, 3, 4, 5, 6],
                "help": "Training stage (1=baseline, 2=medium-long, 3=expansion, 4=extension, 5=longer sentences, 6=full practical)",
            },
        ),
        (
            "--hf-model",
            {
                "type": str,
                "default": None,
                "help": "HuggingFace model ID to fine-tune (required for stage 2+)",
            },
        ),
        (
            "--west-data",
            {
                "type": str,
                "default": DEFAULT_WEST_DATA,
                "help": "Path to West Syriac corpus",
            },
        ),
        (
            "--east-data",
            {
                "type": str,
                "default": DEFAULT_EAST_DATA,
                "help": "Path to East Syriac corpus",
            },
        ),
        (
            "--tokeniser",
            {
                "type": str,
                "default": DEFAULT_TOKENISER,
                "help": "Path to tokeniser",
            },
        ),
        (
            "--output-dir",
            {
                "type": str,
                "default": DEFAULT_OUTPUT_DIR,
                "help": "Output directory for checkpoints",
            },
        ),
        (
            "--batch-size",
            {
                "type": int,
                "default": 16,
                "help": "Per-device batch size",
            },
        ),
        (
            "--no-early-stopping",
            {
                "action": "store_true",
                "help": "Disable early stopping",
            },
        ),
        (
            # Bare ``--resume`` yields "auto"; ``--resume PATH`` yields PATH.
            "--resume",
            {
                "type": str,
                "nargs": "?",
                "const": "auto",
                "default": None,
                "help": "Resume from checkpoint. Use --resume for auto-detect or --resume path/to/checkpoint",
            },
        ),
    ]

    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()
|
|
|
|
| |
| |
| |
|
|
|
|
def load_model_and_tokeniser(
    stage: int = 1,
    hf_model: str | None = None,
    tokeniser_path: str = DEFAULT_TOKENISER,
):
    """
    Build or load the T5 model for a stage, together with its tokeniser.

    Stage 1 creates a randomly initialised model sized from the tokeniser;
    every other stage loads pretrained weights from ``hf_model``.

    Args:
        stage: Training stage (1 = from scratch, otherwise fine-tune).
        hf_model: Model identifier passed to ``from_pretrained``; required
            for any stage other than 1.
        tokeniser_path: Location passed to ``T5TokenizerFast.from_pretrained``.

    Returns:
        Tuple of (model, tokeniser).

    Raises:
        ValueError: ``stage`` is not 1 and ``hf_model`` is empty or None.
    """
    # The tokeniser is always loaded first, even when the model comes from the hub.
    tokeniser = T5TokenizerFast.from_pretrained(tokeniser_path)

    if stage != 1:
        if not hf_model:
            raise ValueError(f"Stage {stage} requires --hf-model argument")
        print(f"Stage {stage}: Loading model from HuggingFace: {hf_model}")
        pretrained = T5ForConditionalGeneration.from_pretrained(hf_model)
        return pretrained, tokeniser

    print("Stage 1: Initialising new model from scratch...")
    pad_id = tokeniser.pad_token_id
    # NOTE(review): ``vocab_size`` may not count tokens added after the
    # tokeniser was trained - confirm it matches ``len(tokeniser)``.
    scratch_config = T5Config(
        vocab_size=tokeniser.vocab_size,
        d_model=512,
        d_ff=2048,
        num_layers=6,
        num_heads=8,
        pad_token_id=pad_id,
        # The pad token doubles as the decoder start token.
        decoder_start_token_id=pad_id,
        tie_word_embeddings=True,
    )
    return T5ForConditionalGeneration(scratch_config), tokeniser
|
|
|
|
| |
| |
| |
|
|
|
|
def get_src_length(example):
    """Return the length of the example's source text (the curriculum sort key)."""
    source_text = example["transliteration"]["src"]
    return len(source_text)
|
|
|
|
def create_tokenise_function(tokeniser):
    """
    Return a batch tokenisation function bound to ``tokeniser``.

    Args:
        tokeniser: Callable tokeniser exposing ``pad_token_id``.

    Returns:
        A function mapping a batch to model inputs plus a ``labels`` entry.
    """
    ignored_id = tokeniser.pad_token_id
    east_prefix = "Syriac2EastLatin: "
    west_prefix = "Syriac2WestLatin: "
    # Sources and targets are encoded with identical settings.
    encode_options = {"max_length": 256, "truncation": True, "padding": "max_length"}

    def tokenise_function(example: dict) -> dict:
        """
        Tokenise a batch, prepending a dialect-specific task prefix.

        Task prefixes:
        - "Syriac2WestLatin: " for West Syriac (Serto)
        - "Syriac2EastLatin: " for East Syriac (Madnḥaya)

        Each record under ``example["transliteration"]`` supplies ``src``,
        ``tgt`` and optionally ``dialect``; any dialect other than "east"
        (including a missing one) uses the West prefix.
        """
        pairs = [
            (
                f"{east_prefix if record.get('dialect', 'west') == 'east' else west_prefix}{record['src']}",
                record["tgt"],
            )
            for record in example["transliteration"]
        ]
        sources = [source for source, _ in pairs]
        references = [reference for _, reference in pairs]

        model_inputs = tokeniser(sources, **encode_options)
        reference_ids = tokeniser(references, **encode_options)["input_ids"]

        # Padding positions become -100 in the labels.
        model_inputs["labels"] = [
            [-100 if token == ignored_id else token for token in row]
            for row in reference_ids
        ]
        return model_inputs

    return tokenise_function
|
|
|
|
def _take_short_examples(dataset, threshold, count):
    """Return up to ``count`` shuffled examples with ``src_length <= threshold``."""
    pool = dataset.filter(lambda x, thresh=threshold: x["src_length"] <= thresh)
    return pool.shuffle(seed=42).select(range(min(count, len(pool))))


def _take_exactly(pool, count):
    """
    Return exactly ``count`` shuffled examples from a non-empty ``pool``.

    When the pool is smaller than ``count`` it is repeated enough times to
    cover the request before shuffling.

    Returns:
        Tuple of (examples, repeats); ``repeats`` is 0 when no oversampling
        was required.
    """
    if len(pool) >= count:
        return pool.shuffle(seed=42).select(range(count)), 0
    repeats = (count // len(pool)) + 1
    repeated = concatenate_datasets([pool] * repeats)
    return repeated.shuffle(seed=42).select(range(count)), repeats


def load_and_prepare_data(
    stage_config: dict,
    stage: int = 1,
    west_data: str = DEFAULT_WEST_DATA,
    east_data: str = DEFAULT_EAST_DATA,
):
    """
    Load both corpora and build the train/validation split for one stage.

    The combined corpus is annotated with ``src_length``, filtered to the
    stage's ``max_src_length`` and sampled with one of three strategies:

    * ``middle_oversample`` set: short + 15-100 character + general examples.
    * ``short_mix_ratio > 0`` and ``stage > 1``: short examples mixed with
      either a dedicated "new range" pool plus a remainder, or a main pool.
    * otherwise: a plain shuffled sample of the filtered corpus.

    NOTE(review): the pools overlap and may be oversampled before
    ``train_test_split``, so the same example can land in both the train and
    validation sets.

    Args:
        stage_config: Configuration dict for the current stage
        stage: Training stage number (for logging and mixing logic)
        west_data: Path to West Syriac corpus JSONL file
        east_data: Path to East Syriac corpus JSONL file

    Returns:
        Tuple of (train_dataset, val_dataset) filtered by sequence length.
    """
    print(f"\n{'=' * 60}")
    print(f"Stage {stage}: {stage_config['description']}")
    print(f"{'=' * 60}\n")

    # Load both dialect corpora and merge them.
    print("Loading West Syriac corpus...")
    west_dataset = load_dataset("json", data_files=west_data, split="train")
    print(f" Loaded {len(west_dataset)} examples")

    print("Loading East Syriac corpus...")
    east_dataset = load_dataset("json", data_files=east_data, split="train")
    print(f" Loaded {len(east_dataset)} examples")

    full_dataset = concatenate_datasets([west_dataset, east_dataset])
    print(f"Total combined: {len(full_dataset)} examples")

    # Annotate every example with its source length, then order by it.
    full_dataset = full_dataset.map(
        lambda x: {"src_length": get_src_length(x)}, num_proc=4
    )
    full_dataset = full_dataset.sort("src_length")

    # Apply the stage's maximum source length.
    max_len = stage_config["max_src_length"]
    if max_len is not None:
        print(f"\nFiltering to sequences with src_length <= {max_len} characters...")
        filtered_dataset = full_dataset.filter(lambda x: x["src_length"] <= max_len)
        print(f" After filtering: {len(filtered_dataset)} examples")
    else:
        filtered_dataset = full_dataset
        print("\nNo length filter applied (using all sequences)")

    num_samples = min(stage_config["num_samples"], len(filtered_dataset))
    print(f"\nSampling {num_samples} examples for training...")

    short_mix_ratio = stage_config["short_mix_ratio"]
    middle_oversample = stage_config.get("middle_oversample", False)

    if middle_oversample:
        # Strategy 1: short + middle-range (15-100 chars) + general examples.
        num_short = int(num_samples * short_mix_ratio)
        num_middle = int(num_samples * 0.40)
        num_main = num_samples - num_short - num_middle

        short_threshold = 15
        short_examples = _take_short_examples(full_dataset, short_threshold, num_short)
        print(f" Short examples (≤{short_threshold} chars): {len(short_examples)}")

        middle_pool = filtered_dataset.filter(lambda x: 15 < x["src_length"] <= 100)
        if len(middle_pool) == 0 and num_middle > 0:
            # An empty pool cannot be oversampled (it would divide by zero);
            # continue without middle-range examples instead of crashing.
            print(" WARNING: No examples in middle range!")
            middle_examples = middle_pool
        else:
            middle_examples, repeats = _take_exactly(middle_pool, num_middle)
            if repeats:
                print(
                    f" Middle-range examples (15-100 chars, oversampled): {len(middle_examples)}"
                )
            else:
                print(f" Middle-range examples (15-100 chars): {len(middle_examples)}")

        # A different seed is used so this draw is not the same permutation.
        main_examples = filtered_dataset.shuffle(seed=43).select(
            range(min(num_main, len(filtered_dataset)))
        )
        print(f" Main examples: {len(main_examples)}")

        sampled_dataset = concatenate_datasets(
            [short_examples, middle_examples, main_examples]
        )
        sampled_dataset = sampled_dataset.shuffle(seed=42)

    elif short_mix_ratio > 0 and stage > 1:
        # Strategy 2: mix short examples back in alongside longer ones.
        new_range_ratio = stage_config.get("new_range_ratio", 0)
        new_range_min = stage_config.get("new_range_min", 0)
        short_threshold = stage_config.get("short_threshold", 15)
        num_short = int(num_samples * short_mix_ratio)

        if new_range_ratio > 0 and new_range_min > 0:
            num_new_range = int(num_samples * new_range_ratio)
            num_remainder = num_samples - num_short - num_new_range

            short_examples = _take_short_examples(
                full_dataset, short_threshold, num_short
            )
            print(
                f" Short examples (≤{short_threshold} chars, previous stages): {len(short_examples)}"
            )

            # Examples in the length range newly unlocked by this stage.
            new_range_pool = filtered_dataset.filter(
                lambda x, min_len=new_range_min: x["src_length"] >= min_len
            )
            print(
                f" New range pool ({new_range_min}-{max_len} chars): {len(new_range_pool)} available"
            )

            if len(new_range_pool) == 0 and num_new_range > 0:
                print(" WARNING: No examples in new range!")
                new_range_examples = new_range_pool
            else:
                new_range_examples, repeats = _take_exactly(
                    new_range_pool, num_new_range
                )
                if repeats:
                    print(
                        f" New range examples (oversampled {repeats}x): {len(new_range_examples)}"
                    )
                else:
                    print(f" New range examples: {len(new_range_examples)}")

            remainder_examples = filtered_dataset.shuffle(seed=43).select(
                range(min(num_remainder, len(filtered_dataset)))
            )
            print(f" Remainder examples: {len(remainder_examples)}")

            sampled_dataset = concatenate_datasets(
                [short_examples, new_range_examples, remainder_examples]
            )
            sampled_dataset = sampled_dataset.shuffle(seed=42)
        else:
            num_main = num_samples - num_short

            short_examples = _take_short_examples(
                full_dataset, short_threshold, num_short
            )
            print(
                f" Short examples (≤{short_threshold} chars, previous stages): {len(short_examples)}"
            )

            # Optional lower bound on the main pool's source length.
            min_len = stage_config.get("min_src_length", 0)
            if min_len > 0:
                main_pool = filtered_dataset.filter(
                    lambda x: x["src_length"] >= min_len
                )
                print(
                    f" Main pool after min_length={min_len} filter: {len(main_pool)} examples"
                )
            else:
                main_pool = filtered_dataset

            main_examples = main_pool.shuffle(seed=42).select(
                range(min(num_main, len(main_pool)))
            )
            print(f" Main examples: {len(main_examples)}")

            sampled_dataset = concatenate_datasets([short_examples, main_examples])
            sampled_dataset = sampled_dataset.shuffle(seed=42)
    else:
        # Strategy 3: plain shuffled sample of the filtered corpus.
        sampled_dataset = filtered_dataset.shuffle(seed=42).select(range(num_samples))

    print(f" Final training pool: {len(sampled_dataset)} examples")

    # Later stages hold out a smaller validation share.
    val_ratio = 0.1 if stage >= 4 else 0.2
    dataset_split = sampled_dataset.train_test_split(test_size=val_ratio, seed=42)
    train_dataset = dataset_split["train"]
    val_dataset = dataset_split["test"]

    print(f"\nTrain set: {len(train_dataset)} examples")
    print(f"Validation set: {len(val_dataset)} examples")

    train_lengths = train_dataset["src_length"]
    print("\nSource length statistics (train):")
    print(f" Min: {min(train_lengths)}, Max: {max(train_lengths)}")
    print(
        f" Mean: {np.mean(train_lengths):.1f}, Median: {np.median(train_lengths):.1f}"
    )

    return train_dataset, val_dataset
|
|
|
|
| |
| |
| |
|
|
|
|
def compute_cer(pred_str: str, target_str: str) -> float:
    """
    Return the Character Error Rate of ``pred_str`` against ``target_str``.

    CER = Levenshtein distance (substitutions + insertions + deletions)
    divided by ``len(target_str)``. For an empty target the result is 0.0
    when the prediction is also empty and 1.0 otherwise.
    """
    target_len = len(target_str)
    if target_len == 0:
        return 1.0 if pred_str else 0.0

    # Only the previous row of the edit-distance table is needed at any time.
    previous_row = list(range(target_len + 1))
    for row_index, pred_char in enumerate(pred_str, start=1):
        current_row = [row_index]
        for col_index, target_char in enumerate(target_str, start=1):
            if pred_char == target_char:
                distance = previous_row[col_index - 1]
            else:
                distance = 1 + min(
                    previous_row[col_index],
                    current_row[col_index - 1],
                    previous_row[col_index - 1],
                )
            current_row.append(distance)
        previous_row = current_row

    return previous_row[target_len] / target_len
|
|
|
|
def create_compute_metrics(tokeniser):
    """
    Return an evaluation metrics function bound to ``tokeniser``.

    Args:
        tokeniser: Object exposing ``pad_token_id`` and ``batch_decode``.

    Returns:
        A function taking ``(preds, labels)`` and returning a metrics dict.
    """
    eval_calls = 0

    def compute_metrics(eval_preds):
        """
        Compute CER, exact match and length statistics for one evaluation.

        Returns:
            Dict with ``cer``, ``exact_match``, ``length_ratio`` and
            ``length_penalty``, each averaged over the batch.
        """
        nonlocal eval_calls
        predictions, references = eval_preds

        # -100 placeholders are swapped for the pad id before decoding.
        references = np.where(references != -100, references, tokeniser.pad_token_id)
        predictions = np.where(predictions != -100, predictions, tokeniser.pad_token_id)

        pred_strs = tokeniser.batch_decode(predictions, skip_special_tokens=True)
        label_strs = tokeniser.batch_decode(references, skip_special_tokens=True)

        # Show a handful of decoded samples on every other call (1st, 3rd, ...).
        eval_calls += 1
        if eval_calls % 2 == 1:
            print("\n--- Sample predictions (first 5) ---")
            for hypothesis, reference in zip(pred_strs[:5], label_strs[:5]):
                print(f" Target: '{reference}'")
                print(f" Pred: '{hypothesis}'")
                print(f" CER: {compute_cer(hypothesis, reference):.3f}")
                print()

        cer_scores = []
        exact_matches = []
        length_ratios = []
        for hypothesis, reference in zip(pred_strs, label_strs):
            cer_scores.append(compute_cer(hypothesis, reference))
            exact_matches.append(
                1.0 if hypothesis.strip() == reference.strip() else 0.0
            )
            # Guard against division by zero for empty references.
            length_ratios.append(len(hypothesis) / max(len(reference), 1))

        mean_pred_len = np.mean([len(text) for text in pred_strs])
        mean_label_len = np.mean([len(text) for text in label_strs])
        print(
            f" Avg pred len: {mean_pred_len:.1f}, Avg label len: {mean_label_len:.1f}"
        )

        # Only predictions shorter than their reference are penalised.
        length_penalties = [max(0, 1 - ratio) for ratio in length_ratios]
        avg_length_penalty = np.mean(length_penalties)
        avg_length_ratio = np.mean(length_ratios)
        print(
            f" Avg length ratio: {avg_length_ratio:.3f}, Avg length penalty: {avg_length_penalty:.3f}"
        )

        return {
            "cer": np.mean(cer_scores),
            "exact_match": np.mean(exact_matches),
            "length_ratio": avg_length_ratio,
            "length_penalty": avg_length_penalty,
        }

    return compute_metrics
|
|
|
|
| |
| |
| |
|
|
|
|
def train(args):
    """
    Run one curriculum stage end to end.

    Steps: refresh the corpora, load the model and tokeniser, build and
    tokenise the stage's datasets, configure training and generation, train
    (optionally resuming), save the final model and tokeniser, then attempt a
    best-effort README metrics update.

    Args:
        args: Parsed command-line namespace providing ``stage``, ``hf_model``,
            ``tokeniser``, ``west_data``, ``east_data``, ``output_dir``,
            ``batch_size``, ``no_early_stopping`` and ``resume``.
    """
    ensure_balanced_corpus()

    stage_config = STAGE_CONFIGS[args.stage]

    model, tokeniser = load_model_and_tokeniser(
        stage=args.stage,
        hf_model=args.hf_model,
        tokeniser_path=args.tokeniser,
    )
    model.gradient_checkpointing_enable()

    train_dataset, val_dataset = load_and_prepare_data(
        stage_config=stage_config,
        stage=args.stage,
        west_data=args.west_data,
        east_data=args.east_data,
    )

    # Replace the raw columns with tokenised model inputs.
    tokenise_fn = create_tokenise_function(tokeniser)
    tokenised_train = train_dataset.map(
        tokenise_fn,
        batched=True,
        remove_columns=train_dataset.column_names,
        desc="Tokenising train set",
    )
    tokenised_eval = val_dataset.map(
        tokenise_fn,
        batched=True,
        remove_columns=val_dataset.column_names,
        desc="Tokenising eval set",
    )

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokeniser, model=model)

    # Stage-dependent optimisation settings.
    grad_accum = 4 if args.stage <= 2 else 8
    if args.stage == 1:
        label_smooth = 0.05
    elif args.stage <= 3:
        label_smooth = 0.08
    else:
        label_smooth = 0.1
    warmup = 0.10 if args.stage == 1 else 0.06

    training_args = Seq2SeqTrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=stage_config["num_epochs"],
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=grad_accum,
        learning_rate=stage_config["learning_rate"],
        warmup_ratio=warmup,
        weight_decay=0.01,
        label_smoothing_factor=label_smooth,
        save_strategy="epoch",
        save_total_limit=3,
        eval_strategy="epoch",
        logging_dir="logs",
        fp16=torch.cuda.is_available(),
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="none",
        predict_with_generate=True,
        generation_max_length=256,
    )

    # Generation settings used when evaluating with generate().
    generation_config = model.generation_config
    generation_config.max_length = 256
    generation_config.eos_token_id = tokeniser.eos_token_id
    generation_config.pad_token_id = tokeniser.pad_token_id
    generation_config.min_length = 2 if args.stage < 5 else 3

    # (num_beams, length_penalty) per stage; stages 5+ share the last entry.
    beam_settings = {
        1: (2, 1.05),
        2: (2, 1.12),
        3: (3, 1.18),
        4: (4, 1.22),
    }
    num_beams, length_penalty = beam_settings.get(args.stage, (4, 1.25))
    generation_config.num_beams = num_beams
    generation_config.length_penalty = length_penalty
    generation_config.early_stopping = True

    # Honour the per-stage repetition penalty; previously the value was
    # configured in STAGE_CONFIGS but never applied.
    if "repetition_penalty" in stage_config:
        generation_config.repetition_penalty = stage_config["repetition_penalty"]

    callbacks = []
    if not args.no_early_stopping:
        callbacks.append(
            EarlyStoppingCallback(
                early_stopping_patience=EARLY_STOPPING_PATIENCE,
                early_stopping_threshold=EARLY_STOPPING_THRESHOLD,
            )
        )
        print("\nEarly stopping enabled:")
        print(f" Patience: {EARLY_STOPPING_PATIENCE} evaluations")
        # The threshold is an absolute change in the monitored metric, not a
        # percentage, so it is reported as such.
        print(
            f" Threshold: {EARLY_STOPPING_THRESHOLD} minimum eval_loss improvement (absolute)"
        )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenised_train,
        eval_dataset=tokenised_eval,
        data_collator=data_collator,
        processing_class=tokeniser,
        compute_metrics=create_compute_metrics(tokeniser),
        callbacks=callbacks,
    )

    print(f"\n{'=' * 60}")
    print("Starting training...")
    print(f"{'=' * 60}\n")

    # ``--resume`` alone asks the trainer to find the latest checkpoint;
    # ``--resume PATH`` resumes from that specific checkpoint.
    resume_from_checkpoint = None
    if args.resume:
        if args.resume == "auto":
            resume_from_checkpoint = True
            print("Resuming from last checkpoint (auto-detect)...")
        else:
            resume_from_checkpoint = args.resume
            print(f"Resuming from checkpoint: {args.resume}")

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    final_output_dir = f"{args.output_dir}/stage{args.stage}-final"
    model.save_pretrained(final_output_dir)
    tokeniser.save_pretrained(final_output_dir)

    print(f"\n{'=' * 60}")
    print(f"Stage {args.stage} training complete!")
    print(f"Model saved to: {final_output_dir}")
    print(f"{'=' * 60}")

    # Best-effort README refresh: a failure here must never fail the run.
    try:
        from update_readme_metrics import (
            extract_metrics,
            find_best_checkpoint,
            update_readme_metrics,
        )

        checkpoint_dir = find_best_checkpoint(Path(args.output_dir))
        if checkpoint_dir:
            metrics = extract_metrics(checkpoint_dir)
            readme_path = _PROJECT_ROOT / "README.md"
            if update_readme_metrics(readme_path, metrics):
                print(
                    f"\nREADME metrics updated automatically from {checkpoint_dir.name}"
                )
    except Exception as e:
        print(f"\nNote: Could not auto-update README metrics: {e}")

    print("\nNext steps:")
    print(
        f" 1. Upload model to HuggingFace (e.g., 'your-username/aramt5-v{args.stage}')"
    )
    # Only point at a following stage when one is actually configured.
    if args.stage < max(STAGE_CONFIGS):
        print(f" 2. Run stage {args.stage + 1}:")
        print(
            f" python src/train_t5.py --stage {args.stage + 1} --hf-model your-username/aramt5-v{args.stage}"
        )
|
|
|
|
# Script entry point: parse the command-line options and run the chosen stage.
if __name__ == "__main__":
    args = parse_args()
    train(args)
|
|