nmt-translator / joeynmt_wrapper.py
Soha368's picture
NMT Translator - English to Spanish
a888ccc
#!/usr/bin/env python3
"""
JoeyNMT Wrapper — patches for PyTorch 2.7+ compatibility.
PyTorch 2.7 removed the `verbose` kwarg from ReduceLROnPlateau.
JoeyNMT 2.3.0 still passes it, so we monkey-patch the scheduler
before importing JoeyNMT.
On Linux this also patches symlink handling to work around any
permission issues.
Usage (called by run.py, not directly):
python joeynmt_wrapper.py train config.yaml
python joeynmt_wrapper.py test config.yaml --ckpt best.ckpt --output_path out.txt
"""
import sys
import os
import shutil
from pathlib import Path
from typing import Optional
# --- Patch ReduceLROnPlateau BEFORE importing joeynmt ---
import torch.optim.lr_scheduler as _lr_sched
_OrigPlateau = _lr_sched.ReduceLROnPlateau


class _PatchedPlateau(_OrigPlateau):
    """ReduceLROnPlateau variant that tolerates a legacy ``verbose`` keyword.

    A ``verbose=...`` keyword argument is discarded; every other positional
    and keyword argument is handed to the original scheduler unchanged.
    """

    def __init__(self, *args, **kwargs):
        # Forward everything except ``verbose`` so callers that still pass
        # it keep working against schedulers that no longer accept it.
        forwarded = {
            key: value for key, value in kwargs.items() if key != "verbose"
        }
        super().__init__(*args, **forwarded)


# Install the tolerant subclass under the scheduler's public name.
_lr_sched.ReduceLROnPlateau = _PatchedPlateau
# --- Now safe to import joeynmt ---
import joeynmt.helpers
import joeynmt.training
# ---------------------------------------------------------------------------
# Symlink patch — use file copy instead of symlinks (works everywhere)
# ---------------------------------------------------------------------------
def _safe_symlink_update(target: Path, link_name: Path) -> Optional[Path]:
    """Point ``link_name`` at ``target`` by copying the file, not symlinking.

    Replacement for ``joeynmt.helpers.symlink_update`` that avoids OS
    symlinks (and the permission issues they can cause). Because a plain
    copy carries no record of where it came from, the name of the copied
    file is stored in a sidecar text file ``<link_name>.target`` next to it.

    Args:
        target: File to publish, given relative to ``link_name``'s directory.
        link_name: Path of the pseudo-link to (re)create.

    Returns:
        ``link_name.parent / <name recorded in the sidecar>`` for the link
        that was replaced, or ``None`` when there was no previous link or no
        usable sidecar entry for it.

    Note:
        If the target file does not exist, the old link is still removed,
        nothing is copied, and the sidecar is left as it was.
    """
    sidecar = link_name.parent / (link_name.name + ".target")
    previous = None

    # exists() follows symlinks and is False for a dangling one. Checking
    # is_symlink() as well makes sure a stale real symlink is removed here;
    # otherwise the copy below would write *through* it and recreate the
    # file it used to point to.
    if link_name.exists() or link_name.is_symlink():
        if sidecar.exists():
            recorded = sidecar.read_text(encoding="utf-8").strip()
            # An empty sidecar would otherwise resolve to the directory.
            if recorded:
                previous = link_name.parent / recorded
        link_name.unlink(missing_ok=True)

    source = link_name.parent / target
    if source.exists():
        shutil.copy2(str(source), str(link_name))
        sidecar.write_text(str(target), encoding="utf-8")

    return previous
# Rebind the helper on the joeynmt.helpers module to the copy-based version.
# NOTE(review): this only affects lookups made through the module attribute;
# any module that already did `from joeynmt.helpers import symlink_update`
# keeps its own reference to the original — confirm against joeynmt.training.
joeynmt.helpers.symlink_update = _safe_symlink_update
# ---------------------------------------------------------------------------
# Checkpoint-save patch — skip symlink assertion
# ---------------------------------------------------------------------------
# Handle on the upstream method, captured before it is overwritten below.
# Nothing in the visible part of this file calls it afterwards.
_orig_save = joeynmt.training.TrainManager._save_checkpoint
def _patched_save(self, new_best, score):
    """Write a checkpoint for the current step and refresh latest/best copies.

    Stand-in for ``TrainManager._save_checkpoint``. ``latest.ckpt`` and
    ``best.ckpt`` are published through ``_safe_symlink_update`` (file
    copies), and no assertion is made about what those links resolve to.

    Args:
        self: The ``TrainManager`` instance. Reads ``model_dir``,
            ``stats.steps``, ``model``, ``optimizer``, ``scheduler``,
            ``scaler``, ``args.keep_best_ckpts`` and
            ``args.minimize_metric``; mutates ``ckpt_queue``.
        new_best: Truthy when this checkpoint should also become
            ``best.ckpt``.
        score: Score used to rank the checkpoint in ``ckpt_queue``. A NaN
            score is never queued.
    """
    import math
    import heapq
    import logging
    import torch

    logger = logging.getLogger(__name__)
    model_dir = Path(self.model_dir)
    ckpt_name = f"{self.stats.steps}.ckpt"
    model_path = model_dir / ckpt_name

    # NOTE(review): only these five entries are stored — confirm that the
    # loader used for resuming/testing does not expect further keys.
    state = {
        "steps": self.stats.steps,
        "model_state": self.model.state_dict(),
        "optimizer_state": self.optimizer.state_dict(),
        "scheduler_state": (
            self.scheduler.state_dict() if self.scheduler is not None else None
        ),
        "scaler_state": (
            self.scaler.state_dict() if self.scaler is not None else None
        ),
    }
    torch.save(state, model_path.as_posix())
    logger.info("Checkpoint saved in %s.", model_path)

    # Refresh the copy-based "links". The previous targets returned by the
    # helper are deliberately ignored: superseded checkpoints stay on disk
    # unless the best-k queue below evicts them.
    link_target = Path(ckpt_name)
    _safe_symlink_update(link_target, model_dir / "latest.ckpt")
    if new_best:
        _safe_symlink_update(link_target, model_dir / "best.ckpt")

    # Maintain the queue of the best `keep_best_ckpts` checkpoints.
    evicted = None
    if not math.isnan(score) and self.args.keep_best_ckpts > 0:
        entry = (score, model_path)
        if len(self.ckpt_queue) < self.args.keep_best_ckpts:
            # Still room: push only, nothing to evict.
            heapq.heappush(self.ckpt_queue, entry)
        elif self.args.minimize_metric:
            # Lower is better: evict the entry with the largest score.
            # The queue is re-heapified as a max-heap on every eviction, so
            # the min-heap push afterwards does not need to keep it ordered.
            heapq._heapify_max(self.ckpt_queue)
            evicted = heapq._heappop_max(self.ckpt_queue)
            heapq.heappush(self.ckpt_queue, entry)
        else:
            # Higher is better: evict the smallest score. heapreplace always
            # inserts the new entry, so the fresh checkpoint is never the
            # one that gets evicted.
            evicted = heapq.heapreplace(self.ckpt_queue, entry)

    if evicted is not None:
        _, path_to_delete = evicted
        try:
            Path(path_to_delete).unlink()
        except FileNotFoundError:
            pass  # already gone; nothing to remove or report
        else:
            logger.info("Removed old ckpt: %s", path_to_delete)
# Replace the method on the class so every TrainManager instance saves
# checkpoints through _patched_save.
joeynmt.training.TrainManager._save_checkpoint = _patched_save
# ---------------------------------------------------------------------------
# Entry point — delegates to joeynmt CLI
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Imported here, after all patches in this module have been installed,
    # and only when the file is run as a script.
    from joeynmt.__main__ import main

    # Hand control to JoeyNMT's own entry point.
    # NOTE(review): presumably it reads the command line shown in the module
    # docstring (train/test, config path, options) — confirm in joeynmt.
    main()