File size: 4,296 Bytes
0cfefd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""自动把项目依赖升级到 PyPI 最新版。

特点:
  - 从 ``pyproject.toml`` 读取 ``project.dependencies`` 与
    ``project.optional-dependencies`` 中的 ``dev`` 组;
  - 直接调用 ``pip install --upgrade <pkg>`` 把所有第三方依赖升级;
  - 为 ``torch`` / ``torchvision`` / ``torchaudio`` 提供单独的 CUDA index URL
    选项(``--torch-index https://download.pytorch.org/whl/cu124``);
  - 升级后调用 ``pip freeze`` 把锁定版本写入 ``requirements.lock.txt``,便于
    在 HF Sandbox / Jobs 环境中复现。

注意:本脚本会修改本地 venv!若需要安全演练,加 ``--dry-run``。
"""

from __future__ import annotations

import argparse
import re
import shlex
import shutil
import subprocess
import sys
from pathlib import Path

try:
    import tomllib  # py3.11+
except ImportError:  # pragma: no cover
    import tomli as tomllib  # type: ignore


# Project root (where pyproject.toml lives): the parent of this script's directory.
ROOT = Path(__file__).resolve().parent.parent
# Project metadata that supplies the dependency lists.
PYPROJECT = ROOT.joinpath("pyproject.toml")
# Destination of the post-upgrade ``pip freeze`` snapshot.
LOCK_FILE = ROOT.joinpath("requirements.lock.txt")
# Packages upgraded in their own pip call (optionally via --torch-index).
TORCH_PKGS = {"torch", "torchvision", "torchaudio"}


def parse_pyproject() -> tuple[list[str], list[str]]:
    """返回 (主依赖, dev 依赖) 的纯包名列表。"""
    data = tomllib.loads(PYPROJECT.read_text(encoding="utf-8"))
    main = [
        _strip_spec(d) for d in data.get("project", {}).get("dependencies", [])
    ]
    dev = [
        _strip_spec(d)
        for d in data.get("project", {})
        .get("optional-dependencies", {})
        .get("dev", [])
    ]
    return main, dev


def _strip_spec(req: str) -> str:
    """去掉版本约束,只留包名。"""
    name = req.split(";")[0]  # 去掉 environment marker
    for sym in ("[", ">=", "<=", "==", "~=", ">", "<", "!=", "@"):
        if sym in name:
            name = name.split(sym)[0]
    return name.strip()


def run(cmd: list[str], dry_run: bool = False) -> int:
    """Echo *cmd* shell-style, then execute it unless *dry_run* is set.

    Returns the command's exit status, or 0 in dry-run mode.
    """
    # shlex.join quotes arguments containing spaces, so the echoed line can be
    # copy-pasted. flush=True keeps the echo ahead of the child's output when
    # stdout is a pipe; a block-buffered print would otherwise surface later.
    print("$", shlex.join(cmd), flush=True)
    if dry_run:
        return 0
    return subprocess.call(cmd)


def upgrade(pkgs: list[str], extra_index: str | None, dry_run: bool, with_pre: bool = False) -> None:
    base = [sys.executable, "-m", "pip", "install", "--upgrade"]
    if with_pre:
        base.append("--pre")
    if extra_index:
        base += ["--extra-index-url", extra_index]
    rc = run(base + pkgs, dry_run=dry_run)
    if rc != 0:
        sys.exit(rc)


def _is_torch_pkg(name: str) -> bool:
    """Return True if *name* is one of ``TORCH_PKGS``, ignoring case."""
    # PyPI project names are case-insensitive, so "Torch>=2" must match too.
    return name.lower() in TORCH_PKGS


def _write_lock_file(dry_run: bool) -> None:
    """Write the ``pip freeze`` output to ``LOCK_FILE``; only print when *dry_run*.

    Raises ``subprocess.CalledProcessError`` if ``pip freeze`` fails. In that
    case the existing lock file is left untouched.
    """
    print("[update_deps] 写入锁定文件 ...")
    if dry_run:
        return
    # Capture the output before touching the file. Opening LOCK_FILE with "w"
    # first would truncate the previous lock even when pip freeze then fails.
    frozen = subprocess.run(
        [sys.executable, "-m", "pip", "freeze"],
        stdout=subprocess.PIPE,
        check=True,
    ).stdout
    LOCK_FILE.write_bytes(frozen)
    print(f"[update_deps] 锁定版本已写入 {LOCK_FILE}")


def main() -> None:
    """CLI entry point: upgrade the project's dependencies, then write a lock file.

    Upgrade order:

    1. pip / setuptools / wheel
    2. the torch family, passing ``--torch-index`` as an extra index
    3. the remaining main dependencies
    4. dev dependencies, unless ``--no-dev`` is given

    A failing pip call terminates the process via ``upgrade``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--torch-index",
        default=None,
        help="PyTorch CUDA wheel 索引(如 https://download.pytorch.org/whl/cu124)",
    )
    parser.add_argument("--no-dev", action="store_true", help="不升级 dev 依赖")
    parser.add_argument("--dry-run", action="store_true", help="只打印命令不执行")
    parser.add_argument("--with-pre", action="store_true", help="允许升级到 pre-release")
    args = parser.parse_args()

    if not PYPROJECT.exists():
        print(f"[update_deps] 找不到 {PYPROJECT}", file=sys.stderr)
        sys.exit(1)

    main_deps, dev_deps = parse_pyproject()
    # Drop names that could not be parsed (empty), and dev entries that the main
    # dependencies already cover; those would just be upgraded twice.
    main_deps = [p for p in main_deps if p]
    seen = set(main_deps)
    dev_deps = [] if args.no_dev else [p for p in dev_deps if p and p not in seen]

    # The torch family gets its own pip call (with the dedicated index),
    # whether it is listed as a main or a dev dependency.
    torch_deps = [p for p in main_deps + dev_deps if _is_torch_pkg(p)]
    other_deps = [p for p in main_deps if not _is_torch_pkg(p)]
    dev_deps = [p for p in dev_deps if not _is_torch_pkg(p)]

    print("[update_deps] 升级 pip / setuptools / wheel ...")
    upgrade(["pip", "setuptools", "wheel"], extra_index=None, dry_run=args.dry_run)

    if torch_deps:
        print(f"[update_deps] 升级 torch 系列 ({torch_deps}) ...")
        upgrade(torch_deps, extra_index=args.torch_index, dry_run=args.dry_run, with_pre=args.with_pre)

    if other_deps:
        print(f"[update_deps] 升级主依赖 ({len(other_deps)} 个) ...")
        upgrade(other_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre)

    if dev_deps:
        print(f"[update_deps] 升级 dev 依赖 ({len(dev_deps)} 个) ...")
        upgrade(dev_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre)

    _write_lock_file(args.dry_run)
    print("[update_deps] OK")


if __name__ == "__main__":
    # Run the upgrade only when executed as a script, not when imported.
    main()