cxr-vlm-code / utils /_torch_load_compat.py
convitom
f
c9b4129
"""
_torch_load_compat.py
---------------------
PyTorch ≥2.6 flipped the default of `torch.load(..., weights_only=...)` from
False to True. transformers ≤4.50 still calls `torch.load(rng_file)` (no
explicit kwarg) inside `Trainer._load_rng_state()` when resuming from a
checkpoint — the rng_state.pth file contains numpy arrays, so the safe
unpickler rejects it with:
_pickle.UnpicklingError: Weights only load failed.
WeightsUnpickler error: Unsupported global: numpy._core.multiarray._reconstruct
This triggers on Lightning AI (torch 2.6+ in the default `cloudspace` env)
the moment training tries to resume. Fresh runs don't hit it.
The fix is to monkey-patch torch.load so it defaults to weights_only=False
*only when the caller hasn't explicitly opted in*. Modern code that passes
`weights_only=True` (e.g. transformers ≥4.50's safetensors path) is
unaffected.
Imported eagerly by `utils/_quiet.py` so the patch is live in every
train.py / evaluate.py subprocess before `transformers` runs.
Idempotent: calling apply() again is a no-op that still returns True. Safe on
torch <2.6 (or when torch is not installed) too — apply() just returns False
without patching.
"""
def apply() -> bool:
    """Install the ``weights_only``-default shim on ``torch.load``.

    Returns:
        True if ``torch.load`` is patched after this call (including when an
        earlier call already patched it). False when torch cannot be
        imported, its version string has no parsable ``major.minor`` prefix,
        or the version is older than 2.6 (where the legacy default still
        applies and no patch is needed).
    """
    import functools
    import re

    try:
        import torch
    except ImportError:
        return False

    # Parse only the leading "major.minor" so local/pre-release suffixes
    # such as "2.6.0+cu124", "2.7+cpu" or "2.8.0a0" don't defeat the check.
    match = re.match(r"(\d+)\.(\d+)", str(getattr(torch, "__version__", "")))
    if match is None:
        return False
    ver = (int(match.group(1)), int(match.group(2)))
    if ver < (2, 6):
        return False

    # Idempotency: never wrap our own wrapper a second time.
    if getattr(torch.load, "_cxr_vlm_compat_patched", False):
        return True

    _orig_load = torch.load

    @functools.wraps(_orig_load)
    def patched_load(*args, **kwargs):
        # Fall back to the legacy behaviour (weights_only=False) ONLY when the
        # caller didn't pick a value. An explicit None is torch's own
        # "use the default" sentinel, so it is treated as "not picked" too.
        # Code that explicitly sets weights_only=True keeps the safe path.
        if kwargs.get("weights_only") is None:
            kwargs["weights_only"] = False
        return _orig_load(*args, **kwargs)

    # Marker for the idempotency check above, plus a handle on the original
    # function so it can be recovered if needed.
    patched_load._cxr_vlm_compat_patched = True
    patched_load._cxr_vlm_orig = _orig_load
    torch.load = patched_load
    return True
# Import-time side effect: merely importing this module installs the shim.
# _PATCHED records whether torch.load ended up patched.
_PATCHED: bool = apply()