| """ |
| _torch_load_compat.py |
| --------------------- |
| PyTorch ≥2.6 flipped the default of `torch.load(..., weights_only=...)` from |
| False to True. transformers ≤4.50 still calls `torch.load(rng_file)` (no |
| explicit kwarg) inside `Trainer._load_rng_state()` when resuming from a |
| checkpoint — the rng_state.pth file contains numpy arrays, so the safe |
| unpickler rejects it with: |
| |
| _pickle.UnpicklingError: Weights only load failed. |
| WeightsUnpickler error: Unsupported global: numpy._core.multiarray._reconstruct |
| |
| This triggers on Lightning AI (torch 2.6+ in the default `cloudspace` env) |
| the moment training tries to resume. Fresh runs don't hit it. |
| |
| The fix is to monkey-patch torch.load so it defaults to weights_only=False |
| *only when the caller hasn't explicitly opted in*. Modern code that passes |
| `weights_only=True` (e.g. transformers ≥4.50's safetensors path) is |
| unaffected. |
| |
| Imported eagerly by `utils/_quiet.py` so the patch is live in every |
| train.py / evaluate.py subprocess before `transformers` runs. |
| |
| Idempotent: calling apply() twice is a no-op. Safe on torch <2.6 too — |
| just returns False without patching. |
| """ |
|
|
|
|
def apply() -> bool:
    """Install the ``weights_only``-default shim on ``torch.load``.

    After patching, a ``torch.load`` call that omits ``weights_only`` (or
    passes ``weights_only=None``, e.g. from a wrapper that forwards its own
    ``None`` default) is forwarded with ``weights_only=False``. An explicit
    ``True`` or ``False`` from the caller is passed through unchanged.

    The patch rebinds the ``torch.load`` module attribute. Code that captured
    the function before this ran (e.g. ``from torch import load``) keeps the
    unpatched original.

    SECURITY: ``weights_only=False`` uses the full pickle unpickler, which can
    execute arbitrary code embedded in the file. Only load checkpoints from
    trusted sources while this shim is active.

    Returns:
        True if ``torch.load`` is patched after the call, whether by this
        call or an earlier one. False if torch is not importable, its version
        string cannot be parsed, or it is older than 2.6.
    """
    import functools

    try:
        import torch
    except ImportError:
        return False

    # Compare only major.minor. Local or pre-release suffixes such as
    # "2.6.0+cu124" live in later components and are ignored.
    try:
        ver = tuple(int(x) for x in torch.__version__.split(".")[:2])
    except (AttributeError, TypeError, ValueError):
        # Missing or unparseable version: do not guess, leave torch alone.
        return False
    if ver < (2, 6):
        return False

    # Idempotence: the marker attribute is set on the installed wrapper below.
    if getattr(torch.load, "_cxr_vlm_compat_patched", False):
        return True

    _orig_load = torch.load

    # functools.wraps preserves torch.load's name, docstring, and __wrapped__.
    # That keeps help() and inspect.signature() pointing at the real API.
    @functools.wraps(_orig_load)
    def patched_load(*args, **kwargs):
        # torch.load declares `weights_only` keyword-only, so it can only
        # arrive via **kwargs. An explicit None means "use torch's default",
        # which is True on 2.6+. Treat it like an omitted argument so
        # forwarding wrappers are covered too.
        if kwargs.get("weights_only") is None:
            kwargs["weights_only"] = False
        return _orig_load(*args, **kwargs)

    patched_load._cxr_vlm_compat_patched = True
    # Kept (alongside __wrapped__) so callers can restore the original.
    patched_load._cxr_vlm_orig = _orig_load
    torch.load = patched_load
    return True
|
|
|
|
| |
# Install the shim as a side effect of importing this module, so any later
# lookup of the torch.load attribute resolves to the wrapper.
# True iff the shim is in place after import.
_PATCHED: bool = apply()
|
|