File size: 4,936 Bytes
61ba51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#!/usr/bin/env python3

import argparse
import re
import sys
from pathlib import Path

from utils import compare_versions, get_repo_root, normalize_version, validate_version

# Repository-relative paths of every file that pins the flashinfer version.
# replace_flashinfer_version() picks its rewrite rule from each file's name,
# so every entry here must have a basename that function recognises.
FILES_TO_UPDATE = [
    Path(rel_path)
    for rel_path in (
        "python/pyproject.toml",
        "docker/Dockerfile",
        "scripts/ci/cuda/ci_install_dependency.sh",
        "python/sglang/srt/entrypoints/engine.py",
        "python/sglang/srt/utils/common.py",
    )
]


def read_current_flashinfer_version(repo_root: Path) -> str:
    """Return the flashinfer_python version pinned in python/pyproject.toml.

    Args:
        repo_root: Root directory of the repository.

    Returns:
        The version from the first ``flashinfer_python==X.Y.Z`` pin in the
        file, including an optional ``rcN`` or ``.postN`` suffix.

    Raises:
        FileNotFoundError: If python/pyproject.toml does not exist.
        ValueError: If the file contains no ``flashinfer_python==`` pin.
    """
    pyproject = repo_root / "python" / "pyproject.toml"
    # TOML documents are UTF-8 by specification; don't depend on the
    # platform's locale default encoding.
    content = pyproject.read_text(encoding="utf-8")
    match = re.search(
        r"flashinfer_python==(\d+\.\d+\.\d+(?:rc\d+|\.post\d+)?)", content
    )
    if not match:
        raise ValueError(f"Could not find flashinfer_python version in {pyproject}")
    return match.group(1)


def replace_flashinfer_version(
    file_path: Path, old_version: str, new_version: str
) -> bool:
    """Rewrite the flashinfer version pin inside one known file.

    The rewrite rule is selected by ``file_path.name``:

    * ``pyproject.toml`` -- ``flashinfer_python==`` and ``flashinfer_cubin==`` pins
    * ``Dockerfile`` -- ``ARG FLASHINFER_VERSION=``
    * ``ci_install_dependency.sh`` -- ``FLASHINFER_VERSION=``
    * ``engine.py`` -- the quoted version argument of
      ``assert_pkg_version("flashinfer_python", ...)``
    * ``common.py`` -- an ``e.g., "<version>"`` example string

    Files with any other name are read but never modified.

    Args:
        file_path: Path of the file to update.
        old_version: Version currently pinned in the file.
        new_version: Version to pin instead.

    Returns:
        True if the file was rewritten; False if it does not exist or
        contained no matching occurrence of ``old_version``.
    """
    if not file_path.exists():
        print(f"Warning: {file_path} does not exist, skipping")
        return False

    # Explicit UTF-8 so reading and writing back does not depend on the
    # platform's locale encoding (which could mangle non-ASCII content).
    content = file_path.read_text(encoding="utf-8")

    old = re.escape(old_version)
    # Refuse to match when old_version is merely a prefix of a longer version
    # token, e.g. "0.6.1" inside "0.6.10" or "0.6.4" inside "0.6.4.post1";
    # a plain prefix match would corrupt those into "0.6.20" / "0.6.5.post1".
    not_longer = r"(?![\w.])"

    name = file_path.name
    if name == "pyproject.toml":
        pattern = rf"(flashinfer_(?:python|cubin)==){old}{not_longer}"
    elif name == "Dockerfile":
        pattern = rf"(ARG FLASHINFER_VERSION=){old}{not_longer}"
    elif name == "ci_install_dependency.sh":
        pattern = rf"(FLASHINFER_VERSION=){old}{not_longer}"
    elif name == "engine.py":
        # The call may be split across lines; \s* also matches newlines.
        # The closing quote delimits the version, so no extra guard is needed.
        pattern = rf'(assert_pkg_version\(\s*"flashinfer_python",\s*"){old}(?=")'
    elif name == "common.py":
        pattern = rf'(e\.g\., "){old}(?=")'
    else:
        pattern = None

    if pattern is None:
        new_content = content
    else:
        # A callable replacement inserts new_version literally, without
        # interpreting backslashes or group references in it.
        new_content = re.sub(pattern, lambda m: m.group(1) + new_version, content)

    if content == new_content:
        print(f"No changes needed in {file_path}")
        return False

    file_path.write_text(new_content, encoding="utf-8")
    print(f"✓ Updated {file_path}")
    return True


def _update_files(repo_root: Path, old_version: str, new_version: str) -> int:
    """Apply the bump to every entry of FILES_TO_UPDATE; return how many changed."""
    updated_count = 0
    for file_rel in FILES_TO_UPDATE:
        if replace_flashinfer_version(repo_root / file_rel, old_version, new_version):
            updated_count += 1
    return updated_count


def _validate_files(repo_root: Path, new_version: str) -> list:
    """Return the FILES_TO_UPDATE entries that exist but do not mention new_version.

    Missing files are reported with a warning and are not counted as failures.
    """
    failed_files = []
    for file_rel in FILES_TO_UPDATE:
        file_abs = repo_root / file_rel
        if not file_abs.exists():
            print(f"Warning: File {file_rel} does not exist, skipping validation.")
            continue

        # Only a substring check: it confirms the new version appears
        # somewhere in the file, not that every old pin was replaced.
        # Read as UTF-8 to match how replace_flashinfer_version writes files.
        if new_version in file_abs.read_text(encoding="utf-8"):
            print(f"✓ {file_rel} validated")
        else:
            failed_files.append(file_rel)
            print(f"✗ {file_rel} does not contain version {new_version}")
    return failed_files


def main():
    """Command-line entry point: bump the flashinfer pin across the repository.

    Exits with status 1 if the requested version is malformed, is not newer
    than the version currently pinned in python/pyproject.toml, or is missing
    afterwards from any target file that exists.
    """
    parser = argparse.ArgumentParser(
        description="Bump flashinfer version across all relevant files"
    )
    parser.add_argument(
        "new_version",
        help="New version (e.g., 0.6.4, 0.6.4rc0, or 0.6.4.post1)",
    )
    args = parser.parse_args()

    new_version = normalize_version(args.new_version)

    if not validate_version(new_version):
        print(f"Error: Invalid version format: {new_version}")
        print("Expected format: X.Y.Z, X.Y.ZrcN, or X.Y.Z.postN")
        print("Examples: 0.6.4, 0.6.4rc0, 0.6.4.post1")
        sys.exit(1)

    repo_root = get_repo_root()
    old_version = read_current_flashinfer_version(repo_root)
    print(f"Current flashinfer version: {old_version}")
    print(f"New flashinfer version: {new_version}")
    print()

    # Only forward bumps are allowed.
    comparison = compare_versions(new_version, old_version)
    if comparison == 0:
        print("Error: New version is the same as current version")
        sys.exit(1)
    elif comparison < 0:
        print(
            f"Error: New version ({new_version}) is older than current version ({old_version})"
        )
        print("Version must be greater than the current version")
        sys.exit(1)

    updated_count = _update_files(repo_root, old_version, new_version)

    print()
    print(f"Successfully updated {updated_count} file(s)")
    print(f"Flashinfer version bumped from {old_version} to {new_version}")

    print("\nValidating version updates...")
    failed_files = _validate_files(repo_root, new_version)

    if failed_files:
        print(f"\nError: {len(failed_files)} file(s) were not updated correctly:")
        for file_rel in failed_files:
            print(f"  - {file_rel}")
        sys.exit(1)

    print("\nAll files validated successfully!")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()