File size: 3,862 Bytes
06883ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8356dae
 
 
 
06883ee
 
 
8356dae
 
06883ee
 
 
 
 
 
 
 
 
 
 
8356dae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06883ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
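"""One-shot helper: replace the stale httpx<0.28 pin in the notebook's cell-pip.

The old cell pinned `httpx<0.28` to work around the huggingface_hub 0.27 +
httpx 0.28 incompatibility (httpx 0.28 removed the legacy `allow_redirects`
kwarg in favour of `follow_redirects`). huggingface_hub v1.x has migrated to
`follow_redirects`, but on modern Colab images firebase-admin (and
google-genai) pull httpx==0.28.1 back in. As a result the pin never holds, and
the assertion at the end of the old cell fires falsely. The new cell source
drops the httpx pin. Instead, it monkey-patches httpx in the kernel so the
legacy `allow_redirects` kwarg is translated to `follow_redirects`.

Idempotent: re-running just replaces the cell source again.
"""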
import json
from pathlib import Path

# Notebook to patch: expected to live in the same directory as this script.
NB_PATH = Path(__file__).resolve().parent / "cxrvlm_colab_train.ipynb"


# Replacement source for the notebook's `cell-pip` cell; main() writes it into
# the cell verbatim (split into lines by src_to_lines). It is a raw string so
# the `\` shell line-continuations in the `!pip install` command reach the
# notebook unchanged instead of being consumed as Python line continuations.
# The cell no longer pins httpx: it defines and runs _patch_httpx(), which on
# httpx >= 0.28 wraps the request/get/head/post/put/patch/delete/options
# methods of httpx.Client and httpx.AsyncClient to rename a legacy
# `allow_redirects` kwarg to `follow_redirects`.
# NOTE: any edit inside this string changes what lands in the notebook.
NEW_SRC = r'''!pip uninstall -y -q torchao transformers bitsandbytes peft accelerate

# Let pip pick latest bnb that matches Colab's CUDA 12.8 + triton 3.x
!pip install -q -U bitsandbytes

# Install everything. We DON'T pin httpx anymore — Colab's firebase-admin and
# google-genai hard-pin httpx==0.28.1, so the resolver always wins. Instead
# we monkey-patch httpx 0.28+ below to keep accepting the legacy
# `allow_redirects` kwarg that transformers ≤4.50 still passes.
!pip install -q \
    'transformers>=4.46,<4.50' \
    'peft>=0.13,<0.15' \
    'accelerate>=1.0' \
    'huggingface_hub>=0.27,<1.0' \
    omegaconf sentencepiece 'protobuf>=3.20' \
    nltk rouge-score bert-score sacrebleu

import torch, transformers, bitsandbytes, peft, accelerate, huggingface_hub, httpx
print('torch          :', torch.__version__, '| cuda:', torch.cuda.is_available())
print('transformers   :', transformers.__version__)
print('bitsandbytes   :', bitsandbytes.__version__)
print('peft           :', peft.__version__)
print('accelerate     :', accelerate.__version__)
print('huggingface_hub:', huggingface_hub.__version__)
print('httpx          :', httpx.__version__)

# ── httpx 0.28+ compat shim ───────────────────────────────────────────────
# transformers ≤4.50 calls httpx.Client.head(..., allow_redirects=True) which
# httpx 0.28 removed → "Client.head() got an unexpected keyword argument
# 'allow_redirects'". Translate the kwarg at the call site so the rest of
# the stack keeps working. No-op on httpx <0.28.
#
# The same patch is auto-applied inside the train.py subprocess via
# utils._quiet → utils._httpx_compat. Here we apply it in the NOTEBOOK
# kernel too, so the smoke test cell (which runs in-kernel) benefits.
def _patch_httpx():
    if tuple(int(x) for x in httpx.__version__.split('.')[:2]) < (0, 28):
        return
    if getattr(httpx.Client, '_cxr_vlm_compat_patched', False):
        return
    def _make(orig):
        def patched(self, *args, **kwargs):
            if 'allow_redirects' in kwargs:
                kwargs['follow_redirects'] = kwargs.pop('allow_redirects')
            return orig(self, *args, **kwargs)
        return patched
    for cls in (httpx.Client, httpx.AsyncClient):
        for m in ('request', 'get', 'head', 'post', 'put',
                  'patch', 'delete', 'options'):
            if hasattr(cls, m):
                setattr(cls, m, _make(getattr(cls, m)))
    httpx.Client._cxr_vlm_compat_patched = True
    print(f'httpx {httpx.__version__}: monkey-patched allow_redirects → follow_redirects')

_patch_httpx()
'''


def src_to_lines(s: str):
    """Convert *s* into the list-of-lines form used by a notebook cell's ``source``.

    Each complete line keeps its trailing newline character. A final fragment
    without a newline is kept unchanged. An empty final fragment (i.e. *s* is
    empty or ends with a newline) is omitted. Only the newline character counts
    as a line break here, unlike ``str.splitlines``, which also splits on
    carriage returns and other separators.
    """
    *complete, remainder = s.split("\n")
    result = [f"{line}\n" for line in complete]
    if remainder:
        result.append(remainder)
    return result


def main(nb_path=None):
    """Replace the ``cell-pip`` cell of the training notebook with NEW_SRC.

    The matching cell's source is overwritten, its outputs are cleared and its
    execution count is reset. Running the script twice leaves the notebook in
    the same state.

    Args:
        nb_path: Notebook file to patch (``str`` or ``Path``). Defaults to
            ``NB_PATH``, the notebook next to this script.

    Raises:
        RuntimeError: If no cell has ``id == "cell-pip"``, or if that cell is
            explicitly not a code cell. In nbformat v4 only code cells carry
            ``outputs``/``execution_count``.
    """
    path = NB_PATH if nb_path is None else Path(nb_path)
    nb = json.loads(path.read_text(encoding="utf-8"))

    cell = next((c for c in nb["cells"] if c.get("id") == "cell-pip"), None)
    if cell is None:
        raise RuntimeError("cell-pip not found")
    # Writing outputs/execution_count onto a markdown/raw cell would produce an
    # invalid notebook, so refuse rather than silently corrupting it.
    if cell.get("cell_type", "code") != "code":
        raise RuntimeError(
            f"cell-pip is a {cell['cell_type']!r} cell, expected a code cell"
        )

    cell["source"] = src_to_lines(NEW_SRC)
    cell["outputs"] = []
    cell["execution_count"] = None

    # Same serialisation as before: indent=1, raw UTF-8, trailing newline.
    text = json.dumps(nb, indent=1, ensure_ascii=False) + "\n"
    # Write to a sibling temp file and then rename it over the notebook, so an
    # interrupted or failed write cannot leave a truncated notebook behind.
    tmp = path.with_name(path.name + ".tmp")
    try:
        tmp.write_text(text, encoding="utf-8")
        tmp.replace(path)
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise
    print(f"Patched cell-pip in {path}")


# Script entry point: patch the notebook only when run directly, not on import.
if __name__ == "__main__":
    main()