File size: 4,936 Bytes
61ba51e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 | #!/usr/bin/env python3
import argparse
import re
import sys
from pathlib import Path
from utils import compare_versions, get_repo_root, normalize_version, validate_version
# Repository-relative paths that main() hands to replace_flashinfer_version()
# and then re-reads to confirm the new version string is present.
FILES_TO_UPDATE = [
    Path(relative_path)
    for relative_path in (
        "python/pyproject.toml",
        "docker/Dockerfile",
        "scripts/ci/cuda/ci_install_dependency.sh",
        "python/sglang/srt/entrypoints/engine.py",
        "python/sglang/srt/utils/common.py",
    )
]
def read_current_flashinfer_version(repo_root: Path) -> str:
    """Read the current flashinfer version from python/pyproject.toml.

    Args:
        repo_root: Root directory of the repository checkout.

    Returns:
        The version captured from the first ``flashinfer_python==`` pin in
        the file, in one of the forms ``X.Y.Z``, ``X.Y.ZrcN`` or
        ``X.Y.Z.postN``.

    Raises:
        ValueError: If the file contains no ``flashinfer_python==X.Y.Z`` pin.
    """
    pyproject = repo_root / "python" / "pyproject.toml"
    # TOML documents are UTF-8 by specification, so decode explicitly instead
    # of relying on whatever the current locale's default encoding is.
    content = pyproject.read_text(encoding="utf-8")
    match = re.search(
        r"flashinfer_python==(\d+\.\d+\.\d+(?:rc\d+|\.post\d+)?)", content
    )
    if not match:
        raise ValueError(f"Could not find flashinfer_python version in {pyproject}")
    return match.group(1)
def replace_flashinfer_version(
    file_path: Path, old_version: str, new_version: str
) -> bool:
    """Rewrite the flashinfer version pinned in a single file.

    The file's base name selects what gets rewritten:

    * ``pyproject.toml``: ``flashinfer_python==`` and ``flashinfer_cubin==`` pins
    * ``Dockerfile``: ``ARG FLASHINFER_VERSION=``
    * ``ci_install_dependency.sh``: ``FLASHINFER_VERSION=``
    * ``engine.py``: the quoted version following
      ``assert_pkg_version("flashinfer_python",``
    * ``common.py``: the ``e.g., "<version>"`` example text

    A file with any other name is left unchanged.

    Args:
        file_path: File to rewrite in place.
        old_version: Version string currently pinned.
        new_version: Version string to write instead.

    Returns:
        True if the file content changed and was written back; False if the
        file does not exist or nothing matched.
    """
    if not file_path.exists():
        print(f"Warning: {file_path} does not exist, skipping")
        return False

    # Decode/encode explicitly so the result does not depend on the locale.
    content = file_path.read_text(encoding="utf-8")

    def bump(prefix: str, end: str = r"(?![\w.])") -> str:
        """Replace ``old_version`` where it directly follows ``prefix``.

        ``end`` is a zero-width assertion for what must follow the version.
        The default refuses to match when the version continues (for example
        old ``0.6.4`` inside ``0.6.4.post1`` or ``0.6.40``), which would
        otherwise leave a corrupted pin such as ``0.6.5.post1`` behind.
        """
        pattern = f"({prefix}){re.escape(old_version)}{end}"
        # A function replacement inserts new_version literally, with no
        # backslash/group-reference interpretation.
        return re.sub(pattern, lambda m: m.group(1) + new_version, content)

    name = file_path.name
    if name == "pyproject.toml":
        new_content = bump(r"flashinfer_(?:python|cubin)==")
    elif name == "Dockerfile":
        new_content = bump(r"ARG FLASHINFER_VERSION=")
    elif name == "ci_install_dependency.sh":
        new_content = bump(r"FLASHINFER_VERSION=")
    elif name == "engine.py":
        # \s* already spans newlines, so the call may be split across lines.
        # The version must be immediately followed by its closing quote.
        new_content = bump(
            r'assert_pkg_version\(\s*"flashinfer_python",\s*"', end=r'(?=")'
        )
    elif name == "common.py":
        # The surrounding quotes already delimit the version exactly.
        new_content = content.replace(
            f'e.g., "{old_version}"',
            f'e.g., "{new_version}"',
        )
    else:
        new_content = content

    if content == new_content:
        print(f"No changes needed in {file_path}")
        return False

    file_path.write_text(new_content, encoding="utf-8")
    print(f"✓ Updated {file_path}")
    return True
def main():
    """Bump the flashinfer version across FILES_TO_UPDATE from the CLI.

    Steps:
      1. Parse the ``new_version`` argument, normalize it and validate it.
      2. Read the currently pinned version from python/pyproject.toml.
      3. Refuse to continue unless the new version compares greater than the
         current one.
      4. Rewrite every file in FILES_TO_UPDATE.
      5. Re-read each existing file and check that it contains the new
         version string.

    Exits with status 1 if the version is rejected, is not greater than the
    current version, or any file fails the final check.
    """
    parser = argparse.ArgumentParser(
        description="Bump flashinfer version across all relevant files"
    )
    parser.add_argument(
        "new_version",
        help="New version (e.g., 0.6.4, 0.6.4rc0, or 0.6.4.post1)",
    )
    args = parser.parse_args()

    new_version = normalize_version(args.new_version)
    if not validate_version(new_version):
        print(f"Error: Invalid version format: {new_version}")
        print("Expected format: X.Y.Z, X.Y.ZrcN, or X.Y.Z.postN")
        print("Examples: 0.6.4, 0.6.4rc0, 0.6.4.post1")
        sys.exit(1)

    repo_root = get_repo_root()
    old_version = read_current_flashinfer_version(repo_root)
    print(f"Current flashinfer version: {old_version}")
    print(f"New flashinfer version: {new_version}")
    print()

    # The branches below treat 0 as "equal" and a negative result as
    # "new is older"; presumably compare_versions is a cmp-style helper.
    comparison = compare_versions(new_version, old_version)
    if comparison == 0:
        print("Error: New version is the same as current version")
        sys.exit(1)
    elif comparison < 0:
        print(
            f"Error: New version ({new_version}) is older than current version ({old_version})"
        )
        print("Version must be greater than the current version")
        sys.exit(1)

    updated_count = 0
    for file_rel in FILES_TO_UPDATE:
        file_abs = repo_root / file_rel
        if replace_flashinfer_version(file_abs, old_version, new_version):
            updated_count += 1

    print()
    print(f"Successfully updated {updated_count} file(s)")
    print(f"Flashinfer version bumped from {old_version} to {new_version}")

    # Second pass: confirm each file that exists now mentions the new version.
    print("\nValidating version updates...")
    failed_files = []
    for file_rel in FILES_TO_UPDATE:
        file_abs = repo_root / file_rel
        if not file_abs.exists():
            print(f"Warning: File {file_rel} does not exist, skipping validation.")
            continue
        # Decode as UTF-8 explicitly so validation does not fail (or mis-read)
        # under a non-UTF-8 locale.
        content = file_abs.read_text(encoding="utf-8")
        if new_version not in content:
            failed_files.append(file_rel)
            print(f"✗ {file_rel} does not contain version {new_version}")
        else:
            print(f"✓ {file_rel} validated")

    if failed_files:
        print(f"\nError: {len(failed_files)} file(s) were not updated correctly:")
        for file_rel in failed_files:
            print(f" - {file_rel}")
        sys.exit(1)

    print("\nAll files validated successfully!")
# Run the version bump only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|