File size: 2,078 Bytes
c9b4129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
_torch_load_compat.py
---------------------
PyTorch ≥2.6 flipped the default of `torch.load(..., weights_only=...)` from
False to True. transformers ≤4.50 still calls `torch.load(rng_file)` (no
explicit kwarg) inside `Trainer._load_rng_state()` when resuming from a
checkpoint — the rng_state.pth file contains numpy arrays, so the safe
unpickler rejects it with:

    _pickle.UnpicklingError: Weights only load failed.
    WeightsUnpickler error: Unsupported global: numpy._core.multiarray._reconstruct

This triggers on Lightning AI (torch 2.6+ in the default `cloudspace` env)
the moment training tries to resume. Fresh runs don't hit it.

The fix is to monkey-patch torch.load so it defaults to weights_only=False
*only when the caller hasn't passed `weights_only` at all*. Code that
explicitly passes `weights_only=True` (e.g. transformers ≥4.50's safetensors
path) keeps the safe unpickler, and any other explicit value is forwarded
unchanged.

Imported eagerly by `utils/_quiet.py` so the patch is live in every
train.py / evaluate.py subprocess before `transformers` runs.

Idempotent: calling apply() again does not re-wrap torch.load (it just
returns True). Safe on torch <2.6 too — apply() returns False without
patching.
"""


def apply() -> bool:
    """Install the weights_only-default shim on ``torch.load``.

    When torch >= 2.6 is importable, ``torch.load`` is replaced by a thin
    wrapper that supplies ``weights_only=False`` whenever the caller did not
    pass ``weights_only`` at all. Any explicit value (True, False or None) is
    forwarded to the original function untouched.

    The wrapper is built with ``functools.wraps``, so its ``__name__``,
    ``__doc__`` and ``inspect.signature`` still reflect the real
    ``torch.load``. This matters for code that probes the signature for a
    ``weights_only`` parameter. The original function stays reachable via
    ``__wrapped__`` and ``_cxr_vlm_orig``.

    Returns:
        True if the shim is active after this call (newly installed, or
        already installed by an earlier call). False if torch is not
        importable, its version string cannot be parsed, or it is older
        than 2.6.
    """
    import functools
    import re

    try:
        import torch
    except ImportError:
        return False

    # Only the leading "MAJOR.MINOR" matters. Matching it with a regex keeps
    # build/local suffixes ("2.6.0+cu124", "2.7.0a0+git…", "2.7+local") from
    # making the parse fail and silently skipping the patch.
    match = re.match(r"(\d+)\.(\d+)", str(getattr(torch, "__version__", "")))
    if match is None:
        return False
    if (int(match.group(1)), int(match.group(2))) < (2, 6):
        return False

    # Already wrapped by a previous call: don't stack a second wrapper.
    if getattr(torch.load, "_cxr_vlm_compat_patched", False):
        return True

    _orig_load = torch.load

    @functools.wraps(_orig_load)
    def patched_load(*args, **kwargs):
        # Default to legacy behaviour (weights_only=False) ONLY when the
        # caller didn't pass weights_only. Code that explicitly sets
        # weights_only=True keeps the safe path.
        kwargs.setdefault("weights_only", False)
        return _orig_load(*args, **kwargs)

    # Markers: the first makes apply() idempotent; the second keeps a handle
    # to the unpatched function (also available as __wrapped__).
    patched_load._cxr_vlm_compat_patched = True
    patched_load._cxr_vlm_orig = _orig_load
    torch.load = patched_load
    return True


# Importing this module installs the shim immediately, so every later
# torch.load call in the process goes through it. The flag records whether
# the shim is active (see apply()'s return value).
_PATCHED: bool = apply()