# (removed page-scrape residue: "Spaces: / Sleeping / Sleeping" — not part of this source file)
#!/usr/bin/env python3
"""
JoeyNMT Wrapper — patches for PyTorch 2.7+ compatibility.

PyTorch 2.7 removed the `verbose` kwarg from ReduceLROnPlateau.
JoeyNMT 2.3.0 still passes it, so we monkey-patch the scheduler
before importing JoeyNMT.

On Linux this also patches symlink handling to work around any
permission issues.

Usage (called by run.py, not directly):
    python joeynmt_wrapper.py train config.yaml
    python joeynmt_wrapper.py test config.yaml --ckpt best.ckpt --output_path out.txt
"""
import sys
import os
import shutil
from pathlib import Path
from typing import Optional
| # --- Patch ReduceLROnPlateau BEFORE importing joeynmt --- | |
| import torch.optim.lr_scheduler as _lr_sched | |
| _OrigPlateau = _lr_sched.ReduceLROnPlateau | |
| class _PatchedPlateau(_OrigPlateau): | |
| """Accept and silently ignore the `verbose` kwarg removed in PyTorch 2.7.""" | |
| def __init__(self, *args, **kwargs): | |
| kwargs.pop("verbose", None) | |
| super().__init__(*args, **kwargs) | |
| _lr_sched.ReduceLROnPlateau = _PatchedPlateau | |
# --- Now safe to import joeynmt ---
import joeynmt.helpers
import joeynmt.training
| # --------------------------------------------------------------------------- | |
| # Symlink patch β use file copy instead of symlinks (works everywhere) | |
| # --------------------------------------------------------------------------- | |
| def _safe_symlink_update(target: Path, link_name: Path) -> Optional[Path]: | |
| """ | |
| Replacement for joeynmt.helpers.symlink_update that uses file copying | |
| instead of OS symlinks (avoids permission issues). | |
| """ | |
| current_last = None | |
| sidecar = link_name.parent / (link_name.name + ".target") | |
| if link_name.exists(): | |
| if sidecar.exists(): | |
| prev_target = sidecar.read_text(encoding="utf-8").strip() | |
| current_last = link_name.parent / prev_target | |
| link_name.unlink(missing_ok=True) | |
| actual = link_name.parent / target | |
| if actual.exists(): | |
| shutil.copy2(str(actual), str(link_name)) | |
| sidecar.write_text(str(target), encoding="utf-8") | |
| return current_last | |
# Route joeynmt's symlink helper through the copy-based replacement above.
joeynmt.helpers.symlink_update = _safe_symlink_update
# ---------------------------------------------------------------------------
# Checkpoint-save patch — skip symlink assertion
# ---------------------------------------------------------------------------
# NOTE(review): _orig_save is never referenced below — presumably kept so the
# patch can be reverted by hand; confirm before removing.
_orig_save = joeynmt.training.TrainManager._save_checkpoint
def _patched_save(self, new_best, score):
    """Save a checkpoint without joeynmt's symlink-resolve assertion.

    Writes ``<steps>.ckpt`` into ``self.model_dir``, refreshes the
    ``latest.ckpt`` (and, when ``new_best``, ``best.ckpt``) pseudo-links via
    ``_safe_symlink_update``, and maintains ``self.ckpt_queue`` — assumed to
    be a heapq min-heap of ``(score, path)`` tuples holding the best
    ``keep_best_ckpts`` checkpoints (TODO confirm against TrainManager).

    Args:
        new_best: True if this checkpoint is the best seen so far.
        score: Validation score for this checkpoint (NaN disables queueing).
    """
    import math
    import heapq
    import logging
    import torch
    logger = logging.getLogger(__name__)
    model_path = Path(self.model_dir) / f"{self.stats.steps}.ckpt"
    state = {
        "steps": self.stats.steps,
        "model_state": self.model.state_dict(),
        "optimizer_state": self.optimizer.state_dict(),
        "scheduler_state": (
            self.scheduler.state_dict() if self.scheduler is not None else None
        ),
        "scaler_state": (
            self.scaler.state_dict() if self.scaler is not None else None
        ),
    }
    torch.save(state, model_path.as_posix())
    logger.info("Checkpoint saved in %s.", model_path)
    symlink_target = Path(f"{self.stats.steps}.ckpt")
    last_path = Path(self.model_dir) / "latest.ckpt"
    prev_path = _safe_symlink_update(symlink_target, last_path)
    best_path = Path(self.model_dir) / "best.ckpt"
    if new_best:
        prev_path = _safe_symlink_update(symlink_target, best_path)
    to_delete = None
    if not math.isnan(score) and self.args.keep_best_ckpts > 0:
        if len(self.ckpt_queue) < self.args.keep_best_ckpts:
            heapq.heappush(self.ckpt_queue, (score, model_path))
        elif self.args.minimize_metric:
            # Lower is better: evict the entry with the LARGEST score.
            # BUGFIX: the previous code used heapq._heapify_max/_heappop_max
            # (private CPython helpers) and then heappush'd onto the
            # max-heapified list, corrupting the min-heap invariant for all
            # later updates. Remove the max entry explicitly and restore the
            # min-heap instead — same multiset result, valid heap.
            to_delete = max(self.ckpt_queue)
            self.ckpt_queue.remove(to_delete)
            self.ckpt_queue.append((score, model_path))
            heapq.heapify(self.ckpt_queue)
        else:
            # Higher is better: the heap root is the worst entry; swap it
            # out for the new checkpoint in one operation.
            to_delete = heapq.heapreplace(
                self.ckpt_queue, (score, model_path))
    if to_delete is not None:
        _, path_to_delete = to_delete
        if Path(path_to_delete).exists():
            Path(path_to_delete).unlink()
            logger.info("Removed old ckpt: %s", path_to_delete)
    # Deliberate no-op: the upstream implementation deletes the previously
    # linked checkpoint once it drops out of the best-k queue; we keep it,
    # preferring a little extra disk over accidentally deleting a ckpt.
    if prev_path is not None and prev_path.exists():
        if prev_path != model_path:
            still_needed = any(p == prev_path for _, p in self.ckpt_queue)
            if not still_needed:
                pass  # could delete, but safer to keep
# Install the patched saver on joeynmt's TrainManager.
joeynmt.training.TrainManager._save_checkpoint = _patched_save
# ---------------------------------------------------------------------------
# Entry point — delegates to joeynmt CLI
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Import late so every patch above is installed before joeynmt runs.
    from joeynmt.__main__ import main
    main()