# Hanrui / sglang / scripts/ci/utils/slash_command_handler.py
# Uploaded by Lekr0 ("Add files using upload-large-folder tool"), commit 61ba51e (verified).
import glob
import json
import os
import re
import sys
import time
from datetime import datetime, timezone
import requests
from github import Auth, Github
# Configuration
# Path, relative to the current working directory, of the JSON file that maps
# GitHub logins to their slash-command permission flags (read by load_permissions).
PERMISSIONS_FILE_PATH = ".github/CI_PERMISSIONS.json"
def find_workflow_run_url(
    gh_repo,
    workflow_id,
    ref,
    target_stage,
    token,
    dispatch_time,
    pr_head_sha=None,
    max_wait=30,
    request_timeout=10,
):
    """
    Poll for the workflow run URL after dispatch.

    Uses the dynamic run-name feature to identify runs:
    - Fork PRs: display_title = "[stage-name] sha"
    - Non-fork PRs: display_title = "[stage-name]"

    The runs list is fetched every 5 seconds, ``max_wait // 5`` times.

    Args:
        gh_repo: PyGithub repository object (only ``full_name`` is read).
        workflow_id: ID of the workflow that was dispatched.
        ref: Branch/ref the workflow was dispatched on.
        target_stage: The stage name we're looking for.
        token: GitHub API token.
        dispatch_time: Unix timestamp when dispatch was triggered.
        pr_head_sha: PR head SHA (for fork PRs, used to match display_title).
        max_wait: Maximum seconds to wait for the run to appear.
        request_timeout: Per-request timeout in seconds for the GitHub API
            call, so a stalled connection cannot hang the job indefinitely.

    Returns:
        The workflow run URL if found, None otherwise. Network and decoding
        errors while polling are logged and the next poll is attempted; they
        are not raised, because the dispatch itself has already happened.
    """
    # Build expected display_title pattern based on workflow's run-name
    # Format: "[stage-name] sha" for fork PRs, "[stage-name]" for non-fork
    if pr_head_sha:
        expected_title = f"[{target_stage}] {pr_head_sha}"
    else:
        expected_title = f"[{target_stage}]"
    print(f"Looking for workflow run with display_title: {expected_title}")

    # Request pieces do not change between polls.
    runs_url = (
        f"https://api.github.com/repos/{gh_repo.full_name}"
        f"/actions/workflows/{workflow_id}/runs"
    )
    params = {"event": "workflow_dispatch", "branch": ref, "per_page": 10}
    headers = {
        "Authorization": f"Bearer {token}",
        "Accept": "application/vnd.github+json",
    }

    for _ in range(max_wait // 5):
        time.sleep(5)
        # Get recent workflow_dispatch runs for this workflow
        try:
            runs_resp = requests.get(
                runs_url, params=params, headers=headers, timeout=request_timeout
            )
        except requests.RequestException as e:
            print(f"Failed to fetch workflow runs: {e}")
            continue
        if runs_resp.status_code != 200:
            print(f"Failed to fetch workflow runs: {runs_resp.status_code}")
            continue
        try:
            workflow_runs = runs_resp.json().get("workflow_runs", [])
        except ValueError as e:
            print(f"Failed to parse workflow runs response: {e}")
            continue

        for run in workflow_runs:
            # Skip runs created before our dispatch (with 10s tolerance)
            run_created = datetime.fromisoformat(
                run["created_at"].replace("Z", "+00:00")
            ).timestamp()
            if run_created < dispatch_time - 10:
                continue
            # Match by display_title (set by workflow's run-name directive)
            # This is immediately available, unlike job names which require waiting
            display_title = run.get("display_title", "")
            if display_title == expected_title:
                print(
                    f"Found matching workflow run: {run['id']} with title '{display_title}'"
                )
                return run["html_url"]

    print(f"Could not find workflow run after {max_wait} seconds")
    return None
def get_env_var(name):
    """Return environment variable *name*; exit with status 1 if it is unset or empty."""
    value = os.environ.get(name)
    if value:
        return value
    print(f"Error: Environment variable {name} not set.")
    sys.exit(1)
def load_permissions(user_login):
    """
    Reads the permissions JSON from the local file system and returns
    the permissions dict for the specific user.

    Args:
        user_login: GitHub login, looked up as a top-level key of the file.

    Returns:
        The user's permission entry, or None when the file does not exist,
        the user has no entry, or the entry is empty.

    Exits the process with status 1 if the file exists but cannot be read or
    parsed, or if its top level is not a JSON object.
    """
    try:
        print(f"Loading permissions from {PERMISSIONS_FILE_PATH}...")
        if not os.path.exists(PERMISSIONS_FILE_PATH):
            print(f"Error: Permissions file not found at {PERMISSIONS_FILE_PATH}")
            return None
        # JSON is UTF-8 by spec; don't depend on the runner's locale encoding.
        with open(PERMISSIONS_FILE_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, dict):
            # Give a clear reason instead of an AttributeError from data.get().
            raise ValueError(
                f"expected a JSON object at the top level, got {type(data).__name__}"
            )
        user_perms = data.get(user_login)
        if not user_perms:
            print(f"User '{user_login}' not found in permissions file.")
            return None
        return user_perms
    except Exception as e:
        print(f"Failed to load or parse permissions file: {e}")
        sys.exit(1)
def has_sgl_kernel_changes(pr):
    """
    Report whether any file touched by the PR lives under ``sgl-kernel/``.

    A True result tells the caller to do a full workflow rerun, so the kernel
    is rebuilt, instead of only rerunning failed jobs. If the PR's file list
    cannot be fetched, a warning is printed and False is returned.
    """
    try:
        return any(
            changed.filename.startswith("sgl-kernel/") for changed in pr.get_files()
        )
    except Exception as e:
        print(f"Warning: Could not check PR files for sgl-kernel changes: {e}")
        # Default to False to avoid unnecessary full reruns
        return False
def handle_tag_run_ci(gh_repo, pr, comment, user_perms, react_on_success=True):
    """
    Handle the /tag-run-ci-label command by adding the ``run-ci`` label to the PR.

    Requires ``user_perms["can_tag_run_ci_label"]``. When *react_on_success*
    is true, the triggering comment also gets a +1 reaction. ``gh_repo`` is
    accepted for signature parity with the other handlers and is not used.

    Returns True if the label was added, False if permission was denied.
    """
    allowed = user_perms.get("can_tag_run_ci_label", False)
    if allowed:
        print("Permission granted. Adding 'run-ci' label.")
        pr.add_to_labels("run-ci")
        if not react_on_success:
            print("Label added (reaction suppressed).")
        else:
            comment.create_reaction("+1")
            print("Label added and comment reacted.")
        return True
    print("Permission denied: can_tag_run_ci_label is false.")
    return False
def _request_workflow_rerun(run, full_rerun):
    """Ask GitHub to rerun *run* (whole run or failed jobs only); return True if accepted."""
    try:
        if full_rerun:
            accepted = run.rerun()
        else:
            accepted = run.rerun_failed_jobs()
    except Exception as e:
        print(f"Failed to rerun workflow {run.id}: {e}")
        return False
    # NOTE(review): per the PyGithub docs, WorkflowRun.rerun()/rerun_failed_jobs()
    # return a bool (False for a non-201 status) instead of raising, so a rejected
    # request must not be counted. Any other value (e.g. None from a different
    # client version) is treated as accepted, matching the previous behavior.
    if accepted is False:
        print(f"Failed to rerun workflow {run.id}: request was not accepted by GitHub")
        return False
    return True


def handle_rerun_failed_ci(gh_repo, pr, comment, user_perms, react_on_success=True):
    """
    Handles the /rerun-failed-ci command.

    Reruns every completed workflow run for the PR's head commit whose
    conclusion is 'failure' or 'skipped'. Failed runs rerun only their failed
    jobs, unless the PR touches ``sgl-kernel/``; in that case the whole run is
    restarted so the kernel wheels are rebuilt. Skipped runs are always fully
    rerun, because they have no failed jobs.

    Returns True if at least one rerun request was accepted, False otherwise.
    """
    if not user_perms.get("can_rerun_failed_ci", False):
        print("Permission denied: can_rerun_failed_ci is false.")
        return False

    print("Permission granted. Triggering rerun of failed or skipped workflows.")

    # Check if PR has sgl-kernel changes - if so, we need full reruns
    # to ensure sgl-kernel-build-wheels runs and produces fresh artifacts
    sgl_kernel_changes = has_sgl_kernel_changes(pr)
    if sgl_kernel_changes:
        print("PR has sgl-kernel changes - will use full rerun to rebuild kernel")

    # Get the SHA of the latest commit in the PR
    head_sha = pr.head.sha
    print(f"Checking workflows for commit: {head_sha}")

    # List all workflow runs for this commit
    runs = gh_repo.get_workflow_runs(head_sha=head_sha)

    rerun_count = 0
    for run in runs:
        if run.status != "completed":
            continue
        if run.conclusion == "failure":
            print(f"Rerunning failed workflow: {run.name} (ID: {run.id})")
            # Full rerun only when the kernel must be rebuilt; otherwise rerun
            # just the failed jobs for efficiency.
            full_rerun = sgl_kernel_changes
        elif run.conclusion == "skipped":
            print(f"Rerunning skipped workflow: {run.name} (ID: {run.id})")
            # Skipped workflows don't have 'failed jobs', so we use full rerun()
            full_rerun = True
        else:
            continue
        if _request_workflow_rerun(run, full_rerun):
            rerun_count += 1

    if rerun_count > 0:
        print(f"Triggered rerun for {rerun_count} workflows.")
        if react_on_success:
            comment.create_reaction("+1")
        return True
    print("No failed or skipped workflows found to rerun.")
    return False
def handle_rerun_stage(
    gh_repo, pr, comment, user_perms, stage_name, token, react_on_success=True
):
    """
    Handles the /rerun-stage <stage-name> command.

    Triggers a workflow_dispatch that runs only the specified stage and skips
    its dependencies. NVIDIA stages dispatch the "PR Test" workflow; AMD
    stages dispatch "PR Test (AMD)". Fork PRs are dispatched on ``main`` with
    the PR head SHA passed as an input. Non-fork PRs are dispatched on the
    PR branch.

    Args:
        gh_repo: PyGithub repository object.
        pr: PyGithub pull request the command was posted on.
        comment: The triggering issue comment (used for reactions).
        user_perms: Permission flags; requires ``can_rerun_stage``.
        stage_name: Stage to run, or None if the command had no argument.
        token: GitHub API token used for the dispatch request.
        react_on_success: Add a +1 reaction when the dispatch succeeds.

    Returns:
        True if the workflow was dispatched, False otherwise. A failed dispatch
        is reported on the PR. Errors that only affect the follow-up comments
        after a successful dispatch are logged and do not change the result.
    """
    if not user_perms.get("can_rerun_stage", False):
        print("Permission denied: can_rerun_stage is false.")
        return False

    if not stage_name:
        print("Error: No stage name provided")
        comment.create_reaction("confused")
        pr.create_issue_comment(
            f"❌ Please specify a stage name: `/rerun-stage <stage-name>`\n\n"
            f"Examples: `/rerun-stage unit-test-backend-4-gpu`, `/rerun-stage accuracy-test-1-gpu`"
        )
        return False

    print(f"Permission granted. Triggering workflow_dispatch for stage '{stage_name}'.")

    # Valid NVIDIA stage names that support target_stage
    nvidia_stages = [
        "stage-a-test-1",
        "stage-a-cpu-only",
        "stage-b-test-small-1-gpu",
        "stage-b-test-large-1-gpu",
        "stage-b-test-large-2-gpu",
        "stage-b-test-4-gpu-b200",
        "stage-c-test-4-gpu-h100",
        "stage-c-test-8-gpu-h200",
        "stage-c-test-8-gpu-h20",
        "stage-c-test-4-gpu-b200",
        "stage-c-test-4-gpu-gb200",
        "stage-c-test-deepep-4-gpu",
        "stage-c-test-deepep-8-gpu-h200",
        "multimodal-gen-test-1-gpu",
        "multimodal-gen-test-2-gpu",
    ]
    # Valid AMD stage names that support target_stage
    amd_stages = [
        "sgl-kernel-unit-test-amd",
        "sgl-kernel-unit-test-2-gpu-amd",
        "stage-a-test-1-amd",
        "stage-b-test-small-1-gpu-amd",
        "stage-b-test-small-1-gpu-amd-nondeterministic",
        "stage-b-test-small-1-gpu-amd-mi35x",
        "stage-b-test-large-1-gpu-amd",
        "stage-b-test-large-2-gpu-amd",
        "multimodal-gen-test-1-gpu-amd",
        "multimodal-gen-test-2-gpu-amd",
        "stage-c-test-large-8-gpu-amd",
        "stage-c-test-large-8-gpu-amd-mi35x",
    ]
    valid_stages = nvidia_stages + amd_stages
    is_amd_stage = stage_name in amd_stages

    if stage_name not in valid_stages:
        comment.create_reaction("confused")
        pr.create_issue_comment(
            f"❌ Stage `{stage_name}` doesn't support isolated runs yet.\n\n"
            f"**NVIDIA stages:**\n"
            + "\n".join(f"- `{s}`" for s in nvidia_stages)
            + "\n\n**AMD stages:**\n"
            + "\n".join(f"- `{s}`" for s in amd_stages)
            + "\n\nOther stages will be added soon. For now, use `/rerun-failed-ci` for those stages."
        )
        return False

    # Get the appropriate workflow based on stage type
    workflow_name = "PR Test (AMD)" if is_amd_stage else "PR Test"

    # Everything up to and including the dispatch request: any failure here
    # means the stage was NOT triggered and is reported as such on the PR.
    try:
        target_workflow = next(
            (wf for wf in gh_repo.get_workflows() if wf.name == workflow_name), None
        )
        if not target_workflow:
            print(f"Error: {workflow_name} workflow not found")
            raise RuntimeError(f"`{workflow_name}` workflow not found")

        # Check if PR is from a fork by comparing repo owners
        # Handle case where fork repo may have been deleted (pr.head.repo is None)
        is_fork = (
            pr.head.repo is None or pr.head.repo.owner.login != gh_repo.owner.login
        )
        print(f"PR is from fork: {is_fork}")

        if is_amd_stage:
            inputs = {"target_stage": stage_name}
        else:
            inputs = {"version": "release", "target_stage": stage_name}

        # pr_head_sha is used for fork PRs (passed to workflow and used for URL lookup)
        pr_head_sha = None
        if is_fork:
            # For fork PRs: dispatch on main and pass SHA as input
            # This is needed because fork branch names don't exist in the main repo
            ref = "main"
            pr_head_sha = pr.head.sha
            inputs["pr_head_sha"] = pr_head_sha
            print(
                f"Triggering {workflow_name} workflow on ref: {ref}, PR head SHA: {pr_head_sha}"
            )
        else:
            # For non-fork PRs: dispatch on the PR branch directly
            # This allows testing workflow changes before merge
            ref = pr.head.ref
            print(f"Triggering {workflow_name} workflow on branch: {ref}")

        # Record dispatch time before triggering
        dispatch_time = time.time()

        # Use requests directly as PyGithub's create_dispatch only accepts HTTP 204
        dispatch_url = f"https://api.github.com/repos/{gh_repo.full_name}/actions/workflows/{target_workflow.id}/dispatches"
        dispatch_resp = requests.post(
            dispatch_url,
            json={"ref": ref, "inputs": inputs},
            headers={
                "Authorization": f"Bearer {token}",
                "Accept": "application/vnd.github+json",
            },
            timeout=30,
        )
        if dispatch_resp.status_code not in (200, 204):
            print(f"Dispatch failed: {dispatch_resp.status_code} {dispatch_resp.text}")
            raise RuntimeError(
                f"workflow_dispatch returned HTTP {dispatch_resp.status_code}"
            )
    except Exception as e:
        print(f"Error triggering workflow_dispatch: {e}")
        comment.create_reaction("confused")
        pr.create_issue_comment(
            f"❌ Failed to trigger workflow: {str(e)}\n\n"
            f"Please check the logs or contact maintainers."
        )
        return False

    print(f"Successfully triggered workflow for stage '{stage_name}'")

    # The workflow is already dispatched from here on, so errors while posting
    # the follow-up comments must not be reported as a failed trigger.
    try:
        if react_on_success:
            comment.create_reaction("+1")
        pr.create_issue_comment(
            f"✅ Triggered `{stage_name}` to run independently (skipping dependencies)."
        )

        # Poll for the workflow run URL and post follow-up comment
        run_url = find_workflow_run_url(
            gh_repo,
            target_workflow.id,
            ref,
            stage_name,
            token,
            dispatch_time,
            pr_head_sha=pr_head_sha,
            max_wait=30,
        )
        if run_url:
            pr.create_issue_comment(f"🔗 [View workflow run]({run_url})")
        else:
            pr.create_issue_comment(
                f"⚠️ Could not retrieve workflow run URL. "
                f"Check the [Actions tab](https://github.com/{gh_repo.full_name}/actions) for progress."
            )
    except Exception as e:
        print(f"Workflow dispatched, but posting the follow-up failed: {e}")
    return True
# Maps each CUDA CI suite name (the `suite=` argument of register_cuda_ci in a
# test file) to the runner label that /rerun-ut passes to the "Rerun UT"
# workflow as its `runner_label` input. Suites missing here are rejected by
# detect_cuda_suite as unknown.
CUDA_SUITE_TO_RUNNER = {
    "stage-a-test-1": "1-gpu-runner",
    "stage-a-cpu-only": "ubuntu-latest",
    "stage-b-test-small-1-gpu": "1-gpu-5090",
    "stage-b-test-large-1-gpu": "1-gpu-runner",
    "stage-b-test-large-2-gpu": "2-gpu-runner",
    "stage-b-test-4-gpu-b200": "4-gpu-b200",
    "stage-c-test-4-gpu-h100": "4-gpu-h100",
    "stage-c-test-8-gpu-h200": "8-gpu-h200",
    "stage-c-test-8-gpu-h20": "8-gpu-h20",
    "stage-c-test-4-gpu-b200": "4-gpu-b200",
    "stage-c-test-deepep-4-gpu": "4-gpu-h100",
    "stage-c-test-deepep-8-gpu-h200": "8-gpu-h200",
}
# Suites for which /rerun-ut sends `use_deepep: "true"` to the "Rerun UT"
# workflow. What the workflow does with that flag is not visible in this file.
DEEPEP_SUITES = {
    "stage-c-test-8-gpu-h20",
    "stage-c-test-deepep-4-gpu",
    "stage-c-test-deepep-8-gpu-h200",
}
def resolve_test_file(file_part):
    """
    Resolve a user-provided file path to a path relative to test/.

    Supports:
    - Full path: test/registered/core/test_srt_endpoint.py
    - Relative to test/: registered/core/test_srt_endpoint.py
    - Bare filename: test_srt_endpoint.py (glob-matched, must be unique)

    The input comes from a PR comment, so it is handled defensively:
    - bare filenames are matched literally (glob metacharacters are escaped);
    - only regular files count as matches;
    - paths that escape ``test/`` (absolute paths or leading ``..``) are rejected.

    Returns (resolved_path, error_message). On success error_message is None.
    """
    if file_part.startswith("test/"):
        file_part = file_part[len("test/") :]

    if not file_part:
        # An empty pattern would otherwise glob every directory under test/registered/.
        return None, "Please specify a test file, e.g. `/rerun-ut test_srt_endpoint.py`."

    if "/" not in file_part:
        # Escape so characters like `*` or `[` are matched literally, and keep
        # only regular files so a directory name such as `core` is not "resolved".
        pattern = f"test/registered/**/{glob.escape(file_part)}"
        matches = sorted(
            m for m in glob.glob(pattern, recursive=True) if os.path.isfile(m)
        )
        if not matches:
            return (
                None,
                f"No test file found matching `{file_part}` under `test/registered/`.",
            )
        if len(matches) > 1:
            match_list = "\n".join(f"- `{m}`" for m in matches)
            return None, (
                f"Ambiguous filename `{file_part}` — matched {len(matches)} files:\n\n"
                f"{match_list}\n\n"
                f"Please provide the full path, e.g. `/rerun-ut {matches[0]}`"
            )
        return matches[0][len("test/") :], None

    # Reject paths that would resolve outside the test/ directory.
    normalized = os.path.normpath(file_part)
    if (
        os.path.isabs(normalized)
        or normalized == os.pardir
        or normalized.startswith(os.pardir + os.sep)
    ):
        return None, f"Path `{file_part}` must stay inside `test/`."

    full_path = f"test/{file_part}"
    if not os.path.isfile(full_path):
        return None, f"File not found: `{full_path}`"
    return file_part, None
def detect_cuda_suite(file_path_from_test):
    """
    Read a test file and extract the suite from register_cuda_ci(suite="...").

    The first ``register_cuda_ci(`` occurrence whose argument text contains
    ``suite=`` before any closing parenthesis is used.

    Args:
        file_path_from_test: Path of the test file relative to ``test/``.

    Returns:
        (suite_name, runner_label, use_deepep, error_message). On success
        error_message is None. On failure it is a Markdown message for the PR;
        suite_name is still filled in when the suite was found but has no
        known runner.
    """
    full_path = f"test/{file_path_from_test}"
    try:
        # Decode explicitly as UTF-8 rather than relying on the runner's locale.
        with open(full_path, "r", encoding="utf-8") as f:
            content = f.read()
    except (OSError, UnicodeDecodeError) as e:
        # Missing files, directories and undecodable files go through the
        # normal error channel instead of raising into the caller.
        return None, None, False, f"Could not read `{full_path}`: {e}"

    match = re.search(
        r'register_cuda_ci\([^)]*suite\s*=\s*["\']([^"\']+)["\']', content
    )
    if not match:
        return (
            None,
            None,
            False,
            (
                f"No `register_cuda_ci()` found in `{full_path}`.\n\n"
                f"This file may not be a registered CUDA CI test."
            ),
        )

    suite = match.group(1)
    runner = CUDA_SUITE_TO_RUNNER.get(suite)
    if not runner:
        known = ", ".join(f"`{s}`" for s in sorted(CUDA_SUITE_TO_RUNNER))
        return (
            suite,
            None,
            False,
            (
                f"Unknown CUDA suite `{suite}` in `{full_path}`.\n\n"
                f"Known suites: {known}"
            ),
        )

    use_deepep = suite in DEEPEP_SUITES
    return suite, runner, use_deepep, None
def handle_rerun_ut(gh_repo, pr, comment, user_perms, test_spec, token):
    """
    Handles the /rerun-ut <file>::<TestClass.test_method> command.

    Dispatches the "Rerun UT" workflow to run a single test file, optionally
    narrowed by a test selector, on the runner that matches the file's
    ``register_cuda_ci`` suite.

    Args:
        gh_repo: PyGithub repository object.
        pr: PyGithub pull request the command was posted on.
        comment: The triggering issue comment (used for reactions).
        user_perms: Permission flags; needs ``can_rerun_ut`` or ``can_rerun_stage``.
        test_spec: Text after the command, ``<file>[::<selector>]``, or None.
        token: GitHub API token used for the dispatch request.

    Returns:
        True if the workflow was dispatched, False otherwise.
    """
    if not (
        user_perms.get("can_rerun_ut", False)
        or user_perms.get("can_rerun_stage", False)
    ):
        print("Permission denied: neither can_rerun_ut nor can_rerun_stage is true.")
        return False

    if not test_spec:
        comment.create_reaction("confused")
        pr.create_issue_comment(
            "❌ Please specify a test: `/rerun-ut <file>::<TestClass.test_method>`\n\n"
            "Examples:\n"
            "- `/rerun-ut test/registered/core/test_srt_endpoint.py::TestSRTEndpoint.test_simple_decode`\n"
            "- `/rerun-ut registered/core/test_srt_endpoint.py::TestSRTEndpoint`\n"
            "- `/rerun-ut test_srt_endpoint.py`"
        )
        return False

    # Parse spec: split on the first :: to get file path and optional test selector
    file_part, separator, test_selector = test_spec.partition("::")
    file_part = file_part.strip()
    test_selector = test_selector.strip() if separator else None

    if test_selector:
        # The success comment renders the command as `cd test/ && python3 ...`,
        # so the workflow presumably runs it through a shell (TODO confirm in
        # the "Rerun UT" workflow). Allow only plain tokens — dotted test names
        # and simple flags such as -k/-v — so comment text cannot inject shell
        # syntax like `;`, `$(...)` or backticks.
        selector_tokens = test_selector.split()
        if not all(
            re.fullmatch(r"[A-Za-z0-9_.\-]+", token_) for token_ in selector_tokens
        ):
            comment.create_reaction("confused")
            pr.create_issue_comment(
                f"❌ Invalid test selector `{test_selector}`.\n\n"
                "Only letters, digits, `_`, `.` and `-` are allowed, e.g. "
                "`/rerun-ut test_srt_endpoint.py::TestSRTEndpoint.test_simple_decode`."
            )
            return False
        test_selector = " ".join(selector_tokens)

    # Resolve file path
    resolved_path, err = resolve_test_file(file_part)
    if err:
        comment.create_reaction("confused")
        pr.create_issue_comment(f"❌ {err}")
        return False

    # Detect suite and runner; an unreadable file is reported rather than
    # crashing the handler.
    try:
        suite, runner_label, use_deepep, err = detect_cuda_suite(resolved_path)
    except (OSError, UnicodeDecodeError) as e:
        suite, runner_label, use_deepep = None, None, False
        err = f"Could not read `test/{resolved_path}`: {e}"
    if err:
        comment.create_reaction("confused")
        pr.create_issue_comment(f"❌ {err}")
        return False

    # Build test_command: file path (+ optional test selector as unittest arg)
    test_command = resolved_path
    if test_selector:
        test_command = f"{resolved_path} {test_selector}"

    print(
        f"Resolved: file={resolved_path}, selector={test_selector}, "
        f"suite={suite}, runner={runner_label}, deepep={use_deepep}, command='{test_command}'"
    )

    workflow_name = "Rerun UT"
    # Everything up to and including the dispatch: a failure here means the
    # test was NOT triggered and is reported as such on the PR.
    try:
        target_workflow = next(
            (wf for wf in gh_repo.get_workflows() if wf.name == workflow_name), None
        )
        if not target_workflow:
            print(f"Error: {workflow_name} workflow not found")
            raise RuntimeError(f"`{workflow_name}` workflow not found")

        # Handle case where fork repo may have been deleted (pr.head.repo is None)
        is_fork = (
            pr.head.repo is None or pr.head.repo.owner.login != gh_repo.owner.login
        )
        print(f"PR is from fork: {is_fork}")

        inputs = {
            "test_command": test_command,
            "runner_label": runner_label,
            "use_deepep": str(use_deepep).lower(),
        }
        pr_head_sha = None
        if is_fork:
            # Fork branch names don't exist in this repo: dispatch on main and
            # pass the PR head SHA as an input instead.
            ref = "main"
            pr_head_sha = pr.head.sha
            inputs["pr_head_sha"] = pr_head_sha
        else:
            ref = pr.head.ref

        dispatch_time = time.time()
        dispatch_url = f"https://api.github.com/repos/{gh_repo.full_name}/actions/workflows/{target_workflow.id}/dispatches"
        dispatch_resp = requests.post(
            dispatch_url,
            json={"ref": ref, "inputs": inputs},
            headers={
                "Authorization": f"Bearer {token}",
                "Accept": "application/vnd.github+json",
            },
            timeout=30,
        )
        if dispatch_resp.status_code not in (200, 204):
            print(f"Dispatch failed: {dispatch_resp.status_code} {dispatch_resp.text}")
            raise RuntimeError(
                f"workflow_dispatch returned HTTP {dispatch_resp.status_code}"
            )
    except Exception as e:
        print(f"Error triggering rerun-ut: {e}")
        comment.create_reaction("confused")
        pr.create_issue_comment(
            f"❌ Failed to trigger rerun-ut: {str(e)}\n\n"
            f"Please check the logs or contact maintainers."
        )
        return False

    print(f"Successfully triggered rerun-ut: {test_command}")

    # The workflow is already dispatched from here on, so errors while posting
    # the follow-up comments must not be reported as a failed trigger.
    try:
        comment.create_reaction("+1")
        pr.create_issue_comment(
            f"✅ Triggered `/rerun-ut` on `{runner_label}` runner:\n"
            f"```\ncd test/ && python3 {test_command}\n```"
        )

        run_url = find_workflow_run_url(
            gh_repo,
            target_workflow.id,
            ref,
            "rerun-ut",
            token,
            dispatch_time,
            pr_head_sha=pr_head_sha,
            max_wait=30,
        )
        if run_url:
            pr.create_issue_comment(f"🔗 [View workflow run]({run_url})")
        else:
            pr.create_issue_comment(
                f"⚠️ Could not retrieve workflow run URL. "
                f"Check the [Actions tab](https://github.com/{gh_repo.full_name}/actions) for progress."
            )
    except Exception as e:
        print(f"rerun-ut dispatched, but posting the follow-up failed: {e}")
    return True
def _run_tag_and_rerun(repo, pr, comment, user_perms):
    """Run /tag-run-ci-label then /rerun-failed-ci, adding one +1 reaction if either acted."""
    # Perform both actions, but suppress individual reactions
    print("Processing combined command: /tag-and-rerun-ci")
    tagged = handle_tag_run_ci(repo, pr, comment, user_perms, react_on_success=False)

    # Wait for the label to propagate before triggering rerun
    if tagged:
        print("Waiting 5 seconds for label to propagate...")
        time.sleep(5)

    rerun = handle_rerun_failed_ci(
        repo, pr, comment, user_perms, react_on_success=False
    )

    if not (tagged or rerun):
        print("Combined command finished, but no actions were taken.")
        return
    comment.create_reaction("+1")
    print("Combined command processed successfully; reaction added.")


def main():
    """
    Entry point: read the triggering comment from the environment, resolve the
    commenter's permissions, and dispatch the slash command on its first line.
    """
    # 1. Environment (read in this order; the first missing one aborts the run)
    token = get_env_var("GITHUB_TOKEN")
    repo_name = get_env_var("REPO_FULL_NAME")
    pr_number = int(get_env_var("PR_NUMBER"))
    comment_id = int(get_env_var("COMMENT_ID"))
    comment_body = get_env_var("COMMENT_BODY").strip()
    user_login = get_env_var("USER_LOGIN")

    # 2. Permissions from the local file, before any GitHub API traffic
    user_perms = load_permissions(user_login)

    # 3. GitHub API objects
    repo = Github(auth=Auth.Token(token)).get_repo(repo_name)
    pr = repo.get_pull(pr_number)
    comment = repo.get_issue(pr_number).get_comment(comment_id)

    # PR authors may always rerun failed CI and individual UTs on their own PR,
    # even without an entry in CI_PERMISSIONS.json. /tag-run-ci-label and
    # /rerun-stage still need an explicit grant in that file.
    is_author = pr.user.login == user_login
    if is_author:
        if user_perms is not None:
            print(
                f"User {user_login} is the PR author and has existing CI permissions."
            )
        else:
            print(
                f"User {user_login} is the PR author (not in CI_PERMISSIONS.json). "
                "Granting CI rerun permissions."
            )
            user_perms = {}
        user_perms["can_rerun_failed_ci"] = True
        user_perms["can_rerun_ut"] = True

    if not user_perms:
        print(f"User {user_login} does not have any configured permissions. Exiting.")
        return

    # 4. Dispatch on the first line of the comment. The first matching prefix wins,
    # so the order of this table matters.
    first_line = comment_body.split("\n", 1)[0].strip()
    words = first_line.split(maxsplit=1)
    argument = words[1].strip() if len(words) > 1 else None

    commands = (
        ("/tag-run-ci-label", lambda: handle_tag_run_ci(repo, pr, comment, user_perms)),
        ("/rerun-failed-ci", lambda: handle_rerun_failed_ci(repo, pr, comment, user_perms)),
        ("/tag-and-rerun-ci", lambda: _run_tag_and_rerun(repo, pr, comment, user_perms)),
        (
            "/rerun-stage",
            lambda: handle_rerun_stage(repo, pr, comment, user_perms, argument, token),
        ),
        (
            "/rerun-ut",
            lambda: handle_rerun_ut(repo, pr, comment, user_perms, argument, token),
        ),
    )
    for prefix, action in commands:
        if first_line.startswith(prefix):
            action()
            break
    else:
        print(f"Unknown or ignored command: {first_line}")


if __name__ == "__main__":
    main()