"""自动把项目依赖升级到 PyPI 最新版。 特点: - 从 ``pyproject.toml`` 读取 ``project.dependencies`` 与 ``project.optional-dependencies``; - 直接调用 ``pip install --upgrade `` 把所有第三方依赖升级; - 为 ``torch`` / ``torchvision`` / ``torchaudio`` 提供单独的 CUDA index URL 选项(``--torch-index https://download.pytorch.org/whl/cu124``); - 升级后调用 ``pip freeze`` 把锁定版本写入 ``requirements.lock.txt``,便于 在 HF Sandbox / Jobs 环境中复现。 注意:本脚本会修改本地 venv!若需要安全演练,加 ``--dry-run``。 """ from __future__ import annotations import argparse import shutil import subprocess import sys from pathlib import Path try: import tomllib # py3.11+ except ImportError: # pragma: no cover import tomli as tomllib # type: ignore ROOT = Path(__file__).resolve().parent.parent PYPROJECT = ROOT / "pyproject.toml" LOCK_FILE = ROOT / "requirements.lock.txt" TORCH_PKGS = {"torch", "torchvision", "torchaudio"} def parse_pyproject() -> tuple[list[str], list[str]]: """返回 (主依赖, dev 依赖) 的纯包名列表。""" data = tomllib.loads(PYPROJECT.read_text(encoding="utf-8")) main = [ _strip_spec(d) for d in data.get("project", {}).get("dependencies", []) ] dev = [ _strip_spec(d) for d in data.get("project", {}) .get("optional-dependencies", {}) .get("dev", []) ] return main, dev def _strip_spec(req: str) -> str: """去掉版本约束,只留包名。""" name = req.split(";")[0] # 去掉 environment marker for sym in ("[", ">=", "<=", "==", "~=", ">", "<", "!=", "@"): if sym in name: name = name.split(sym)[0] return name.strip() def run(cmd: list[str], dry_run: bool = False) -> int: print("$", " ".join(cmd)) if dry_run: return 0 return subprocess.call(cmd) def upgrade(pkgs: list[str], extra_index: str | None, dry_run: bool, with_pre: bool = False) -> None: base = [sys.executable, "-m", "pip", "install", "--upgrade"] if with_pre: base.append("--pre") if extra_index: base += ["--extra-index-url", extra_index] rc = run(base + pkgs, dry_run=dry_run) if rc != 0: sys.exit(rc) def main() -> None: parser = argparse.ArgumentParser() parser.add_argument( "--torch-index", default=None, help="PyTorch CUDA wheel 索引(如 https://download.pytorch.org/whl/cu124)", ) parser.add_argument("--no-dev", action="store_true", help="不升级 dev 依赖") parser.add_argument("--dry-run", action="store_true", help="只打印命令不执行") parser.add_argument("--with-pre", action="store_true", help="允许升级到 pre-release") args = parser.parse_args() if not PYPROJECT.exists(): print(f"[update_deps] 找不到 {PYPROJECT}", file=sys.stderr) sys.exit(1) main_deps, dev_deps = parse_pyproject() # 把 torch 系列单独处理(用专用索引) torch_deps = [p for p in main_deps if p in TORCH_PKGS] other_deps = [p for p in main_deps if p not in TORCH_PKGS] print(f"[update_deps] 升级 pip / setuptools / wheel ...") upgrade(["pip", "setuptools", "wheel"], extra_index=None, dry_run=args.dry_run) if torch_deps: print(f"[update_deps] 升级 torch 系列 ({torch_deps}) ...") upgrade(torch_deps, extra_index=args.torch_index, dry_run=args.dry_run, with_pre=args.with_pre) if other_deps: print(f"[update_deps] 升级主依赖 ({len(other_deps)} 个) ...") upgrade(other_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre) if dev_deps and not args.no_dev: print(f"[update_deps] 升级 dev 依赖 ({len(dev_deps)} 个) ...") upgrade(dev_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre) print("[update_deps] 写入锁定文件 ...") if not args.dry_run: with open(LOCK_FILE, "w", encoding="utf-8") as f: subprocess.run([sys.executable, "-m", "pip", "freeze"], stdout=f, check=True) print(f"[update_deps] 锁定版本已写入 {LOCK_FILE}") print("[update_deps] OK") if __name__ == "__main__": main()