WJAD / scripts /update_deps.py
fuzirui's picture
Sync WJAD codebase
0cfefd2 verified
"""自动把项目依赖升级到 PyPI 最新版。
特点:
- 从 ``pyproject.toml`` 读取 ``project.dependencies`` 与
  ``project.optional-dependencies`` 中的 ``dev`` 分组;
- 直接调用 ``pip install --upgrade <pkg>`` 把所有第三方依赖升级;
- 为 ``torch`` / ``torchvision`` / ``torchaudio`` 提供单独的 CUDA index URL
选项(``--torch-index https://download.pytorch.org/whl/cu124``);
- 升级后调用 ``pip freeze`` 把锁定版本写入 ``requirements.lock.txt``,便于
在 HF Sandbox / Jobs 环境中复现。
注意:本脚本会修改本地 venv!若需要安全演练,加 ``--dry-run``。
"""
from __future__ import annotations
import argparse
import shutil
import subprocess
import sys
from pathlib import Path
try:
import tomllib # py3.11+
except ImportError: # pragma: no cover
import tomli as tomllib # type: ignore
# Repository root: two levels above this file (this script sits in <root>/scripts/).
ROOT = Path(__file__).resolve().parent.parent
# Project metadata that declares the dependencies to upgrade.
PYPROJECT = ROOT.joinpath("pyproject.toml")
# Destination for the `pip freeze` snapshot taken after upgrading.
LOCK_FILE = ROOT.joinpath("requirements.lock.txt")
# Packages upgraded in their own step so they can use --torch-index.
TORCH_PKGS = {"torch", "torchaudio", "torchvision"}
def parse_pyproject(path: Path | None = None) -> tuple[list[str], list[str]]:
    """Return the bare package names declared in ``pyproject.toml``.

    Args:
        path: TOML file to read. Defaults to ``PYPROJECT`` (the repository's
            ``pyproject.toml``), so existing zero-argument calls behave as
            before.

    Returns:
        ``(main, dev)``: names taken from ``project.dependencies`` and from
        the ``dev`` group of ``project.optional-dependencies``, each reduced
        to a bare name by ``_strip_spec``. Entries that reduce to an empty
        string are dropped so they are never handed to pip.
    """
    source = PYPROJECT if path is None else Path(path)
    data = tomllib.loads(source.read_text(encoding="utf-8"))
    # Missing tables mean "no dependencies" rather than an error.
    project = data.get("project", {})
    optional = project.get("optional-dependencies", {})

    def bare_names(specs: list[str]) -> list[str]:
        # Reduce each requirement string to its name, skipping blanks.
        return [name for name in map(_strip_spec, specs) if name]

    return (
        bare_names(project.get("dependencies", [])),
        bare_names(optional.get("dev", [])),
    )
def _strip_spec(req: str) -> str:
    """Return the distribution name at the start of a PEP 508 requirement.

    Everything after the name is discarded: environment markers (``; ...``),
    extras (``[...]``), version specifiers (including the legacy
    parenthesised form ``name (>=1.0)``) and direct URLs (``@ ...``).
    Surrounding whitespace is ignored.

    If the text does not begin with a valid PEP 508 name, the
    marker-stripped, whitespace-trimmed text is returned unchanged.
    """
    import re  # local import: the file header does not import re

    head = req.split(";", 1)[0].strip()  # drop the environment marker
    # PEP 508 name: ASCII letters/digits, with . _ - allowed only in between.
    match = re.match(r"[A-Za-z0-9](?:[A-Za-z0-9._-]*[A-Za-z0-9])?", head)
    return match.group(0) if match else head
def run(cmd: list[str], dry_run: bool = False) -> int:
    """Echo *cmd* after a ``$`` prompt and, unless *dry_run*, execute it.

    Returns:
        The exit status reported by ``subprocess.call``, or ``0`` when
        *dry_run* is true (in which case nothing is executed).
    """
    shown = " ".join(cmd)
    print("$", shown)
    return 0 if dry_run else subprocess.call(cmd)
def upgrade(pkgs: list[str], extra_index: str | None, dry_run: bool, with_pre: bool = False) -> None:
base = [sys.executable, "-m", "pip", "install", "--upgrade"]
if with_pre:
base.append("--pre")
if extra_index:
base += ["--extra-index-url", extra_index]
rc = run(base + pkgs, dry_run=dry_run)
if rc != 0:
sys.exit(rc)
def _parse_args() -> argparse.Namespace:
    """Parse this script's command-line options from ``sys.argv``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--torch-index",
        default=None,
        help="PyTorch CUDA wheel 索引(如 https://download.pytorch.org/whl/cu124)",
    )
    parser.add_argument("--no-dev", action="store_true", help="不升级 dev 依赖")
    parser.add_argument("--dry-run", action="store_true", help="只打印命令不执行")
    parser.add_argument("--with-pre", action="store_true", help="允许升级到 pre-release")
    return parser.parse_args()


def _write_lock_file(path: Path) -> None:
    """Write the output of ``pip freeze`` to *path*.

    The output is captured before *path* is opened. If ``pip freeze`` fails
    (``CalledProcessError``), an existing lock file is left intact instead
    of being truncated. The bytes are written verbatim.
    """
    frozen = subprocess.run(
        [sys.executable, "-m", "pip", "freeze"],
        stdout=subprocess.PIPE,
        check=True,
    ).stdout
    path.write_bytes(frozen)


def main() -> None:
    """Upgrade all declared dependencies, then refresh the lock file.

    Order of operations:
      1. pip / setuptools / wheel;
      2. the torch family (``TORCH_PKGS``) via the optional ``--torch-index``;
      3. the remaining main dependencies;
      4. the remaining dev dependencies (skipped with ``--no-dev``);
      5. ``pip freeze`` into ``LOCK_FILE`` (skipped with ``--dry-run``).

    Exits with status 1 if ``PYPROJECT`` does not exist. Exits with pip's
    status if any upgrade step fails.
    """
    args = _parse_args()
    if not PYPROJECT.exists():
        print(f"[update_deps] 找不到 {PYPROJECT}", file=sys.stderr)
        sys.exit(1)
    main_deps, dev_deps = parse_pyproject()
    if args.no_dev:
        dev_deps = []

    # Torch packages need the dedicated index wherever they are declared
    # (main or dev). pip treats names case-insensitively, so match on the
    # lower-cased name and deduplicate.
    torch_deps = list(
        dict.fromkeys(
            p.lower() for p in main_deps + dev_deps if p.lower() in TORCH_PKGS
        )
    )
    other_deps = [p for p in main_deps if p.lower() not in TORCH_PKGS]
    dev_deps = [p for p in dev_deps if p.lower() not in TORCH_PKGS]

    print("[update_deps] 升级 pip / setuptools / wheel ...")
    upgrade(["pip", "setuptools", "wheel"], extra_index=None, dry_run=args.dry_run)
    if torch_deps:
        print(f"[update_deps] 升级 torch 系列 ({torch_deps}) ...")
        upgrade(torch_deps, extra_index=args.torch_index, dry_run=args.dry_run, with_pre=args.with_pre)
    if other_deps:
        print(f"[update_deps] 升级主依赖 ({len(other_deps)} 个) ...")
        upgrade(other_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre)
    if dev_deps:
        print(f"[update_deps] 升级 dev 依赖 ({len(dev_deps)} 个) ...")
        upgrade(dev_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre)
    print("[update_deps] 写入锁定文件 ...")
    if not args.dry_run:
        _write_lock_file(LOCK_FILE)
        print(f"[update_deps] 锁定版本已写入 {LOCK_FILE}")
    print("[update_deps] OK")
if __name__ == "__main__":
    # Run the upgrade only when executed as a script, not when imported.
    main()