| """自动把项目依赖升级到 PyPI 最新版。 |
| |
| 特点: |
- 从 ``pyproject.toml`` 读取 ``project.dependencies`` 与
``project.optional-dependencies`` 中的 ``dev`` 分组(其他可选分组不处理);
| - 直接调用 ``pip install --upgrade <pkg>`` 把所有第三方依赖升级; |
| - 为 ``torch`` / ``torchvision`` / ``torchaudio`` 提供单独的 CUDA index URL |
| 选项(``--torch-index https://download.pytorch.org/whl/cu124``); |
| - 升级后调用 ``pip freeze`` 把锁定版本写入 ``requirements.lock.txt``,便于 |
| 在 HF Sandbox / Jobs 环境中复现。 |
| |
| 注意:本脚本会修改本地 venv!若需要安全演练,加 ``--dry-run``。 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import shutil |
| import subprocess |
| import sys |
| from pathlib import Path |
|
|
| try: |
| import tomllib |
| except ImportError: |
| import tomli as tomllib |
|
|
|
|
# Repository root; this script presumably lives one directory below it
# (e.g. ``scripts/``) — TODO confirm if the script is ever moved.
ROOT = Path(__file__).resolve().parent.parent
PYPROJECT = ROOT.joinpath("pyproject.toml")
LOCK_FILE = ROOT.joinpath("requirements.lock.txt")
# Packages routed through the optional ``--torch-index`` (CUDA wheel index).
TORCH_PKGS = {"torchaudio", "torch", "torchvision"}
|
|
|
|
def parse_pyproject() -> tuple[list[str], list[str]]:
    """Return the bare package names of the main and ``dev`` dependencies.

    Reads ``[project].dependencies`` and the ``dev`` group of
    ``[project].optional-dependencies`` from ``PYPROJECT``; other optional
    groups are ignored.

    Clean-ups applied to the names:

    * empty entries are dropped;
    * duplicates are removed, using PEP 503 normalisation (case-insensitive,
      runs of ``-``/``_``/``.`` are equivalent), with the first spelling kept;
    * a package already listed in the main group is not repeated in ``dev``;
    * any reference to the project itself (e.g. ``"myproj[test]"`` inside
      the ``dev`` extra) is skipped. Otherwise the script would run
      ``pip install --upgrade myproj`` and pull a same-named package from
      PyPI instead of keeping the local project.

    Returns:
        ``(main, dev)``: two lists of package names in declaration order.
    """
    import re

    def _canonical(name: str) -> str:
        # PEP 503 name normalisation.
        return re.sub(r"[-_.]+", "-", name).lower()

    data = tomllib.loads(PYPROJECT.read_text(encoding="utf-8"))
    project = data.get("project", {})

    own_name = project.get("name")
    # Canonical names already emitted; seeded with the project's own name so
    # that self-references are skipped.
    seen: set[str] = {_canonical(own_name)} if own_name else set()

    def _collect(reqs: list[str]) -> list[str]:
        # Strip specs and keep only new, non-empty names, preserving order.
        names: list[str] = []
        for req in reqs:
            name = _strip_spec(req)
            key = _canonical(name)
            if not name or key in seen:
                continue
            seen.add(key)
            names.append(name)
        return names

    main = _collect(project.get("dependencies", []))
    dev = _collect(project.get("optional-dependencies", {}).get("dev", []))
    return main, dev
|
|
|
|
| def _strip_spec(req: str) -> str: |
| """去掉版本约束,只留包名。""" |
| name = req.split(";")[0] |
| for sym in ("[", ">=", "<=", "==", "~=", ">", "<", "!=", "@"): |
| if sym in name: |
| name = name.split(sym)[0] |
| return name.strip() |
|
|
|
|
def run(cmd: list[str], dry_run: bool = False) -> int:
    """Echo *cmd* prefixed with ``$`` and, unless *dry_run*, execute it.

    Args:
        cmd: Argument vector, run without a shell.
        dry_run: When true, only print the command.

    Returns:
        The command's exit status, or ``0`` in dry-run mode.
    """
    # flush=True: when stdout is redirected (CI logs, ``| tee``) Python
    # block-buffers it, while the child writes straight to the inherited file
    # descriptor. Without the flush, this echo (and earlier buffered status
    # lines) could show up *after* the child's output.
    print("$", " ".join(cmd), flush=True)
    if dry_run:
        return 0
    return subprocess.call(cmd)
|
|
|
|
| def upgrade(pkgs: list[str], extra_index: str | None, dry_run: bool, with_pre: bool = False) -> None: |
| base = [sys.executable, "-m", "pip", "install", "--upgrade"] |
| if with_pre: |
| base.append("--pre") |
| if extra_index: |
| base += ["--extra-index-url", extra_index] |
| rc = run(base + pkgs, dry_run=dry_run) |
| if rc != 0: |
| sys.exit(rc) |
|
|
|
|
def _write_lock_file() -> None:
    """Write the output of ``pip freeze`` to ``LOCK_FILE``.

    The freeze output is captured in memory before the file is touched.
    If ``pip freeze`` fails (``subprocess.CalledProcessError``), the existing
    lock file is therefore left intact rather than truncated. The bytes are
    written verbatim, exactly as pip emitted them.
    """
    result = subprocess.run(
        [sys.executable, "-m", "pip", "freeze"],
        stdout=subprocess.PIPE,
        check=True,
    )
    LOCK_FILE.write_bytes(result.stdout)


def main() -> None:
    """Command-line entry point: upgrade all declared dependencies, then lock them.

    Steps:
      1. upgrade pip / setuptools / wheel;
      2. upgrade torch-family packages from the main group and, unless
         ``--no-dev``, the ``dev`` group, passing ``--torch-index`` as an
         extra index;
      3. upgrade the remaining main dependencies;
      4. upgrade the remaining ``dev`` dependencies unless ``--no-dev``;
      5. write ``pip freeze`` output to the lock file (skipped with ``--dry-run``).

    Exits with status 1 if ``pyproject.toml`` is missing. Any failing pip
    install call exits with pip's return code.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--torch-index",
        default=None,
        help="PyTorch CUDA wheel 索引(如 https://download.pytorch.org/whl/cu124)",
    )
    parser.add_argument("--no-dev", action="store_true", help="不升级 dev 依赖")
    parser.add_argument("--dry-run", action="store_true", help="只打印命令不执行")
    parser.add_argument("--with-pre", action="store_true", help="允许升级到 pre-release")
    args = parser.parse_args()

    if not PYPROJECT.exists():
        print(f"[update_deps] 找不到 {PYPROJECT}", file=sys.stderr)
        sys.exit(1)

    main_deps, dev_deps = parse_pyproject()
    if args.no_dev:
        dev_deps = []

    def is_torch(name: str) -> bool:
        # Package names are case-insensitive; TORCH_PKGS is lower-case.
        return name.lower() in TORCH_PKGS

    # Torch-family packages from *either* group go through the torch index.
    # Upgrading a dev-only torchvision from plain PyPI may otherwise pull a
    # torch build that differs from the one on the CUDA index.
    # dict.fromkeys de-duplicates while keeping order.
    torch_deps = list(dict.fromkeys(p for p in main_deps + dev_deps if is_torch(p)))
    other_deps = [p for p in main_deps if not is_torch(p)]
    other_dev_deps = [p for p in dev_deps if not is_torch(p)]

    if args.torch_index and not torch_deps:
        print("[update_deps] 未发现 torch 系列依赖,--torch-index 将被忽略", file=sys.stderr)

    print("[update_deps] 升级 pip / setuptools / wheel ...")
    upgrade(["pip", "setuptools", "wheel"], extra_index=None, dry_run=args.dry_run)

    if torch_deps:
        print(f"[update_deps] 升级 torch 系列 ({torch_deps}) ...")
        upgrade(torch_deps, extra_index=args.torch_index, dry_run=args.dry_run, with_pre=args.with_pre)

    if other_deps:
        print(f"[update_deps] 升级主依赖 ({len(other_deps)} 个) ...")
        upgrade(other_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre)

    if other_dev_deps:
        print(f"[update_deps] 升级 dev 依赖 ({len(other_dev_deps)} 个) ...")
        upgrade(other_dev_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre)

    print("[update_deps] 写入锁定文件 ...")
    if not args.dry_run:
        _write_lock_file()
        print(f"[update_deps] 锁定版本已写入 {LOCK_FILE}")
    print("[update_deps] OK")
|
|
|
|
# Script entry point. main() returns None, so SystemExit(None) exits with
# status 0, exactly like falling off the end of the module. Explicit
# sys.exit() calls inside main() propagate unchanged.
if __name__ == "__main__":
    raise SystemExit(main())
|
|