#!/usr/bin/env python3
"""Bump the pinned flashinfer version across the repository.

The current pin is read from ``python/pyproject.toml``; every file listed in
``FILES_TO_UPDATE`` is then rewritten to the new version, and finally each
existing file is checked to confirm it mentions the new version.
"""

import argparse
import re
import sys
from pathlib import Path

from utils import compare_versions, get_repo_root, normalize_version, validate_version

# Repo-relative paths of the files that carry a flashinfer version pin.
FILES_TO_UPDATE = [
    Path("python/pyproject.toml"),
    Path("docker/Dockerfile"),
    Path("scripts/ci/cuda/ci_install_dependency.sh"),
    Path("python/sglang/srt/entrypoints/engine.py"),
    Path("python/sglang/srt/utils/common.py"),
]

# Negative lookahead placed right after the old version: the match must not
# run on into more version text. Without it, old "0.6.4" would match the front
# of "0.6.4.post1" / "0.6.4rc1" and produce a mangled "<new>.post1" pin.
# A trailing sentence period ("... ==0.6.4.") is still accepted.
_VERSION_END = r"(?!\.?\w)"

# File name -> sequence of (prefix regex, suffix regex). For every rule, the
# text "<prefix><old_version>" followed by <suffix> is rewritten to
# "<prefix><new_version>"; the suffix is a zero-width assertion and is never
# consumed. Quoted pins use the closing quote itself as the boundary.
_VERSION_RULES = {
    "pyproject.toml": (
        (r"flashinfer_python==", _VERSION_END),
        (r"flashinfer_cubin==", _VERSION_END),
    ),
    "Dockerfile": ((r"ARG FLASHINFER_VERSION=", _VERSION_END),),
    "ci_install_dependency.sh": ((r"FLASHINFER_VERSION=", _VERSION_END),),
    "engine.py": (
        (r'assert_pkg_version\(\s*"flashinfer_python",\s*"', r'(?=")'),
    ),
    "common.py": ((r'e\.g\., "', r'(?=")'),),
}


def read_current_flashinfer_version(repo_root: Path) -> str:
    """Read the current flashinfer version from python/pyproject.toml.

    Args:
        repo_root: Root directory of the repository.

    Returns:
        The version from the first ``flashinfer_python==X.Y.Z`` pin found,
        including an optional ``rcN`` or ``.postN`` suffix.

    Raises:
        ValueError: If no ``flashinfer_python==`` pin is present in the file.
    """
    pyproject = repo_root / "python" / "pyproject.toml"
    # Explicit encoding: do not depend on the platform's locale default.
    content = pyproject.read_text(encoding="utf-8")
    match = re.search(
        r"flashinfer_python==(\d+\.\d+\.\d+(?:rc\d+|\.post\d+)?)", content
    )
    if not match:
        raise ValueError(f"Could not find flashinfer_python version in {pyproject}")
    return match.group(1)


def replace_flashinfer_version(
    file_path: Path, old_version: str, new_version: str
) -> bool:
    """Rewrite the flashinfer pin in one file from old_version to new_version.

    The rules applied are selected by the file's base name (see
    ``_VERSION_RULES``); a file whose name has no rules is left unchanged.

    Args:
        file_path: Path of the file to rewrite.
        old_version: Version text currently expected in the file.
        new_version: Version text to write in its place.

    Returns:
        True if the file was modified and written back, False if it does not
        exist or nothing in it matched.
    """
    if not file_path.exists():
        print(f"Warning: {file_path} does not exist, skipping")
        return False

    content = file_path.read_text(encoding="utf-8")
    new_content = content
    for prefix, suffix in _VERSION_RULES.get(file_path.name, ()):
        pattern = re.compile(f"({prefix}){re.escape(old_version)}{suffix}")
        # A replacement function inserts new_version literally, so it is never
        # interpreted as a regex replacement template.
        new_content = pattern.sub(lambda m: m.group(1) + new_version, new_content)

    if new_content == content:
        print(f"No changes needed in {file_path}")
        return False

    file_path.write_text(new_content, encoding="utf-8")
    print(f"✓ Updated {file_path}")
    return True


def main():
    """Parse the CLI argument, bump the version everywhere, then validate.

    Exits with status 1 when the requested version is rejected by
    ``validate_version``, when ``compare_versions`` reports it as equal to or
    older than the current version, or when any existing file in
    ``FILES_TO_UPDATE`` does not contain the new version after the update.
    """
    parser = argparse.ArgumentParser(
        description="Bump flashinfer version across all relevant files"
    )
    parser.add_argument(
        "new_version",
        help="New version (e.g., 0.6.4, 0.6.4rc0, or 0.6.4.post1)",
    )
    args = parser.parse_args()

    new_version = normalize_version(args.new_version)
    if not validate_version(new_version):
        print(f"Error: Invalid version format: {new_version}")
        print("Expected format: X.Y.Z, X.Y.ZrcN, or X.Y.Z.postN")
        print("Examples: 0.6.4, 0.6.4rc0, 0.6.4.post1")
        sys.exit(1)

    repo_root = get_repo_root()
    old_version = read_current_flashinfer_version(repo_root)
    print(f"Current flashinfer version: {old_version}")
    print(f"New flashinfer version: {new_version}")
    print()

    # compare_versions is used as a three-way comparison: 0 means equal,
    # a negative result means new_version is older than old_version.
    comparison = compare_versions(new_version, old_version)
    if comparison == 0:
        print("Error: New version is the same as current version")
        sys.exit(1)
    if comparison < 0:
        print(
            f"Error: New version ({new_version}) is older than current version ({old_version})"
        )
        print("Version must be greater than the current version")
        sys.exit(1)

    # Each call returns a bool, so the sum is the number of files rewritten.
    updated_count = sum(
        replace_flashinfer_version(repo_root / file_rel, old_version, new_version)
        for file_rel in FILES_TO_UPDATE
    )

    print()
    print(f"Successfully updated {updated_count} file(s)")
    print(f"Flashinfer version bumped from {old_version} to {new_version}")

    # Second pass: every file that exists must now mention the new version.
    # This catches files whose pin did not match old_version and was therefore
    # left untouched above.
    print("\nValidating version updates...")
    failed_files = []
    for file_rel in FILES_TO_UPDATE:
        file_abs = repo_root / file_rel
        if not file_abs.exists():
            print(f"Warning: File {file_rel} does not exist, skipping validation.")
            continue
        if new_version in file_abs.read_text(encoding="utf-8"):
            print(f"✓ {file_rel} validated")
        else:
            failed_files.append(file_rel)
            print(f"✗ {file_rel} does not contain version {new_version}")

    if failed_files:
        print(f"\nError: {len(failed_files)} file(s) were not updated correctly:")
        for file_rel in failed_files:
            print(f" - {file_rel}")
        sys.exit(1)

    print("\nAll files validated successfully!")


if __name__ == "__main__":
    main()