NeMo
Megatron-LM / .github /scripts /oncall_manager.py
KexuanShi's picture
Upload folder using huggingface_hub
88e6849 verified
Raw
History Blame Contribute Delete
16.7 kB
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import json
import requests
import argparse
from datetime import datetime, timedelta, timezone
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
# Constants
GITHUB_API_URL = "https://api.github.com"
SCHEDULE_FILE = ".github/oncall_schedule.json"
ROTATION_TEAM_SLUG = "mcore-oncall-rotation"
ACTIVE_ONCALL_TEAM_SLUG = "mcore-oncall"
SLACK_USERGROUP_HANDLE = "mcore-oncall"
TARGET_WEEKS = 12
# Caches for email and Slack lookups
_email_cache = {}
_slack_id_cache = {}
def get_headers():
token = os.environ.get("GH_TOKEN")
if not token:
# Fallback to GITHUB_TOKEN if GH_TOKEN not set
token = os.environ.get("GITHUB_TOKEN")
if not token:
print("Error: GH_TOKEN or GITHUB_TOKEN not set")
sys.exit(1)
return {
"Authorization": f"token {token}",
"Accept": "application/vnd.github.v3+json"
}
def get_repo_info():
"""Returns (owner, repo) from GITHUB_REPOSITORY env var."""
repo_env = os.environ.get("GITHUB_REPOSITORY")
if not repo_env:
print("Error: GITHUB_REPOSITORY environment variable not set")
sys.exit(1)
parts = repo_env.split("/")
return parts[0], parts[1]
def get_team_members(org, team_slug):
"""Fetches members of the GitHub team."""
url = f"{GITHUB_API_URL}/orgs/{org}/teams/{team_slug}/members"
headers = get_headers()
members = set()
page = 1
while True:
resp = requests.get(f"{url}?per_page=100&page={page}", headers=headers)
if resp.status_code != 200:
print(f"Error fetching team members: {resp.status_code} {resp.text}")
sys.exit(1)
data = resp.json()
if not data:
break
members.update([m['login'] for m in data])
if len(data) < 100:
break
page += 1
return members
def get_user_email(username):
"""Get user's email from GitHub, prioritizing @nvidia.com emails.
Checks in order:
1. Public profile email
2. Recent commits in the repository
"""
if username in _email_cache:
return _email_cache[username]
headers = get_headers()
public_email = None
try:
# 1. Try to get user's public profile email first
resp = requests.get(f"{GITHUB_API_URL}/users/{username}", headers=headers)
if resp.status_code == 200:
user_data = resp.json()
email = user_data.get('email')
if email and not email.endswith("@users.noreply.github.com"):
if email.endswith("@nvidia.com"):
_email_cache[username] = email
return email
# Store non-nvidia email as fallback
public_email = email
# 2. Check recent commits in the repository for @nvidia.com email
repo_env = os.environ.get("GITHUB_REPOSITORY", "NVIDIA/Megatron-LM")
commits_url = f"{GITHUB_API_URL}/repos/{repo_env}/commits?author={username}&per_page=10"
resp = requests.get(commits_url, headers=headers)
if resp.status_code == 200:
commits = resp.json()
for commit in commits:
# Get email from commit author
commit_data = commit.get('commit', {})
author_data = commit_data.get('author', {})
email = author_data.get('email')
if email and not email.endswith("@users.noreply.github.com"):
if email.endswith("@nvidia.com"):
_email_cache[username] = email
print(f"Found @nvidia.com email for {username} from commits: {email}")
return email
elif public_email is None:
public_email = email
# 3. Use public email if found, otherwise fallback
if public_email:
_email_cache[username] = public_email
print(f"Using public email for {username}: {public_email}")
return public_email
# Fallback to noreply email
fallback = f"{username}@users.noreply.github.com"
_email_cache[username] = fallback
print(f"Warning: No email found for {username}, using fallback: {fallback}")
return fallback
except Exception as e:
print(f"Warning: Could not get email for {username}: {e}")
fallback = f"{username}@users.noreply.github.com"
_email_cache[username] = fallback
return fallback
def get_slack_client():
"""Get Slack WebClient if token is available."""
slack_token = os.environ.get("SLACK_TOKEN")
if not slack_token:
return None
return WebClient(token=slack_token)
def get_slack_user_id(slack_client, email):
"""Get Slack user ID from email."""
if not slack_client:
return None
if email in _slack_id_cache:
return _slack_id_cache[email]
try:
response = slack_client.users_lookupByEmail(email=email)
user_id = response["user"]["id"]
_slack_id_cache[email] = user_id
return user_id
except SlackApiError as e:
print(f"Warning: Could not find Slack user for {email}: {e.response['error']}")
_slack_id_cache[email] = None
return None
def get_slack_usergroup_id(slack_client, handle):
"""Get Slack usergroup ID from handle."""
if not slack_client:
return None
try:
response = slack_client.usergroups_list(include_users=True)
for usergroup in response.get("usergroups", []):
if usergroup.get("handle") == handle:
return usergroup.get("id"), usergroup.get("users", [])
print(f"Warning: Slack usergroup '{handle}' not found")
return None, []
except SlackApiError as e:
print(f"Warning: Could not list Slack usergroups: {e.response['error']}")
return None, []
def update_slack_usergroup(new_oncall_username, old_members_usernames):
"""
Updates the Slack usergroup to contain only the new oncall user.
Adds new oncall first, then removes old members (usergroups need at least one member).
"""
slack_client = get_slack_client()
if not slack_client:
print("Slack token not configured, skipping Slack usergroup update")
return
# Get the new oncall's email and Slack user ID
new_email = get_user_email(new_oncall_username)
new_slack_id = get_slack_user_id(slack_client, new_email)
if not new_slack_id:
print(f"Could not find Slack user ID for {new_oncall_username} ({new_email}), skipping Slack update")
return
# Get the usergroup ID and current members
usergroup_id, current_slack_members = get_slack_usergroup_id(slack_client, SLACK_USERGROUP_HANDLE)
if not usergroup_id:
print(f"Could not find Slack usergroup '{SLACK_USERGROUP_HANDLE}', skipping Slack update")
return
try:
# Step 1: Add new oncall first (include current members to avoid removing anyone yet)
# This ensures usergroup always has at least one member
if new_slack_id not in current_slack_members:
updated_members = list(set(current_slack_members + [new_slack_id]))
slack_client.usergroups_users_update(
usergroup=usergroup_id,
users=updated_members
)
print(f"Added {new_oncall_username} to Slack usergroup '{SLACK_USERGROUP_HANDLE}'")
# Step 2: Now set the usergroup to contain only the new oncall
slack_client.usergroups_users_update(
usergroup=usergroup_id,
users=[new_slack_id]
)
print(f"Updated Slack usergroup '{SLACK_USERGROUP_HANDLE}' to contain only {new_oncall_username}")
except SlackApiError as e:
print(f"Failed to update Slack usergroup: {e.response['error']}")
def load_schedule():
if not os.path.exists(SCHEDULE_FILE):
return []
try:
with open(SCHEDULE_FILE, 'r') as f:
data = json.load(f)
# Normalize to list of dicts if it's a list of strings
schedule = []
for item in data:
if isinstance(item, str):
schedule.append({"user": item, "date": "YYYY-MM-DD"})
else:
schedule.append(item)
return schedule
except (json.JSONDecodeError, FileNotFoundError):
return []
def save_schedule(schedule):
with open(SCHEDULE_FILE, 'w') as f:
json.dump(schedule, f, indent=4)
f.write('\n') # trailing newline
def update_active_oncall_team(org, new_oncall):
"""Updates the active oncall team to contain only the new oncall user."""
# 1. Get current members of the active team
current_members = get_team_members(org, ACTIVE_ONCALL_TEAM_SLUG)
# 2. Add the new oncall if not present
if new_oncall not in current_members:
url = f"{GITHUB_API_URL}/orgs/{org}/teams/{ACTIVE_ONCALL_TEAM_SLUG}/memberships/{new_oncall}"
resp = requests.put(url, headers=get_headers())
if resp.status_code == 200:
print(f"Added {new_oncall} to {ACTIVE_ONCALL_TEAM_SLUG}")
else:
print(f"Failed to add {new_oncall} to {ACTIVE_ONCALL_TEAM_SLUG}: {resp.status_code} {resp.text}")
# 3. Remove everyone else
old_members = []
for member in current_members:
if member not in [new_oncall, 'svcnvidia-nemo-ci']:
old_members.append(member)
url = f"{GITHUB_API_URL}/orgs/{org}/teams/{ACTIVE_ONCALL_TEAM_SLUG}/memberships/{member}"
resp = requests.delete(url, headers=get_headers())
if resp.status_code == 204:
print(f"Removed {member} from {ACTIVE_ONCALL_TEAM_SLUG}")
else:
print(f"Failed to remove {member} from {ACTIVE_ONCALL_TEAM_SLUG}: {resp.status_code} {resp.text}")
# 4. Update Slack usergroup (add new oncall first, then remove old members)
update_slack_usergroup(new_oncall, old_members)
def rotate_schedule(repo_owner, dry_run=False):
schedule = load_schedule()
print(f"Current schedule length: {len(schedule)}")
# 1. Rotate (Remove past week)
# Only if schedule is not empty.
if schedule:
# Check date of first entry
first_entry = schedule[0]
try:
# We assume the date is the *start* of the oncall shift (Wednesday).
# The shift ends 7 days later.
start_date = datetime.strptime(first_entry['date'], "%Y-%m-%d").date()
end_date = start_date + timedelta(days=7)
today = datetime.now(timezone.utc).date()
# If today is >= end_date, the shift is over.
# (e.g. Started last Wed, ends today Wed. If today is Wed, we rotate)
if today >= end_date:
removed = schedule.pop(0)
print(f"Rotated out: {removed} (Ended {end_date})")
else:
print(f"First entry {first_entry} has not ended yet (Ends {end_date}). Not removing.")
except ValueError:
# Fallback if date is invalid, rotate anyway
removed = schedule.pop(0)
print(f"Rotated out (invalid date): {removed}")
else:
print("Schedule empty, nothing to rotate.")
# 2. Replenish
ensure_schedule_filled(schedule, repo_owner)
# 3. Update active oncall team
if schedule:
current_oncall = schedule[0]['user']
print(f"New active oncall: {current_oncall}")
if not dry_run:
update_active_oncall_team(repo_owner, current_oncall)
else:
print(f"Dry run: Would update {ACTIVE_ONCALL_TEAM_SLUG} to contain only {current_oncall}")
if not dry_run:
save_schedule(schedule)
print("Schedule updated and saved.")
else:
print("Dry run: Schedule not saved.")
print(json.dumps(schedule, indent=4))
def get_last_wednesday():
today = datetime.now(timezone.utc).date()
# Monday=0, Wednesday=2
offset = (today.weekday() - 2) % 7
return today - timedelta(days=offset)
def ensure_schedule_filled(schedule, repo_owner):
"""Appends users to schedule until it reaches TARGET_WEEKS."""
members = get_team_members(repo_owner, ROTATION_TEAM_SLUG)
if not members:
print(f"Warning: No team members found in {ROTATION_TEAM_SLUG}.")
return
if 'svcnvidia-nemo-ci' in members:
members.remove('svcnvidia-nemo-ci')
members = list(members)
members.sort() # Deterministic order
while len(schedule) < TARGET_WEEKS:
# Determine start date for the new entry
if not schedule:
# Start with the most recent Wednesday if list is empty
next_date = get_last_wednesday()
# Start with the first member alphabetically if list is empty
next_user = members[0]
else:
last_entry = schedule[-1]
last_user = last_entry['user']
# Parse last date and add 7 days
try:
last_date = datetime.strptime(last_entry['date'], "%Y-%m-%d").date()
next_date = last_date + timedelta(days=7)
except ValueError:
# Fallback if date is invalid/placeholder
next_date = get_last_wednesday() + timedelta(days=7 * len(schedule))
try:
# Find index of last scheduled user in the team list
if last_user in members:
last_idx = members.index(last_user)
next_idx = (last_idx + 1) % len(members)
next_user = members[next_idx]
else:
# Last user not in team, just pick first member
next_user = members[0]
except ValueError:
next_user = members[0]
new_entry = {"user": next_user, "date": next_date.strftime("%Y-%m-%d")}
schedule.append(new_entry)
print(f"Appended: {new_entry}")
def assign_reviewer(pr_number):
"""Assigns the mcore-oncall team as the reviewer for the PR."""
owner, repo = get_repo_info()
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/pulls/{pr_number}/requested_reviewers"
# Assign the oncall team as reviewer
data = {"team_reviewers": [ACTIVE_ONCALL_TEAM_SLUG]}
resp = requests.post(url, headers=get_headers(), json=data)
if resp.status_code in [201, 200]:
print(f"Successfully requested review from team NVIDIA/{ACTIVE_ONCALL_TEAM_SLUG}")
else:
print(f"Failed to request review: {resp.status_code} {resp.text}")
sys.exit(1)
def main():
parser = argparse.ArgumentParser(description="Manage Oncall Schedule")
subparsers = parser.add_subparsers(dest="command", required=True)
# Rotate command
parser_rotate = subparsers.add_parser("rotate", help="Rotate the schedule (remove first, append new)")
parser_rotate.add_argument("--dry-run", action="store_true", help="Do not save changes")
# Fill command (just fill up to 12 without rotating - useful for init)
parser_fill = subparsers.add_parser("fill", help="Fill the schedule to 12 weeks without rotating")
# Assign command
parser_assign = subparsers.add_parser("assign", help="Assign current oncall to PR")
parser_assign.add_argument("--pr", type=int, required=True, help="PR number")
args = parser.parse_args()
owner, _ = get_repo_info()
if args.command == "rotate":
rotate_schedule(owner, dry_run=args.dry_run)
elif args.command == "fill":
schedule = load_schedule()
ensure_schedule_filled(schedule, owner)
save_schedule(schedule)
print("Schedule filled and saved.")
elif args.command == "assign":
assign_reviewer(args.pr)
if __name__ == "__main__":
main()