Instructions to use KexuanShi/Megatron-LM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- NeMo
How to use KexuanShi/Megatron-LM with NeMo:
# tag did not correspond to a valid NeMo domain.
- Notebooks
- Google Colab
- Kaggle
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import sys | |
| import json | |
| import requests | |
| import argparse | |
| from datetime import datetime, timedelta, timezone | |
| from slack_sdk import WebClient | |
| from slack_sdk.errors import SlackApiError | |
| # Constants | |
| GITHUB_API_URL = "https://api.github.com" | |
| SCHEDULE_FILE = ".github/oncall_schedule.json" | |
| ROTATION_TEAM_SLUG = "mcore-oncall-rotation" | |
| ACTIVE_ONCALL_TEAM_SLUG = "mcore-oncall" | |
| SLACK_USERGROUP_HANDLE = "mcore-oncall" | |
| TARGET_WEEKS = 12 | |
| # Caches for email and Slack lookups | |
| _email_cache = {} | |
| _slack_id_cache = {} | |
| def get_headers(): | |
| token = os.environ.get("GH_TOKEN") | |
| if not token: | |
| # Fallback to GITHUB_TOKEN if GH_TOKEN not set | |
| token = os.environ.get("GITHUB_TOKEN") | |
| if not token: | |
| print("Error: GH_TOKEN or GITHUB_TOKEN not set") | |
| sys.exit(1) | |
| return { | |
| "Authorization": f"token {token}", | |
| "Accept": "application/vnd.github.v3+json" | |
| } | |
| def get_repo_info(): | |
| """Returns (owner, repo) from GITHUB_REPOSITORY env var.""" | |
| repo_env = os.environ.get("GITHUB_REPOSITORY") | |
| if not repo_env: | |
| print("Error: GITHUB_REPOSITORY environment variable not set") | |
| sys.exit(1) | |
| parts = repo_env.split("/") | |
| return parts[0], parts[1] | |
| def get_team_members(org, team_slug): | |
| """Fetches members of the GitHub team.""" | |
| url = f"{GITHUB_API_URL}/orgs/{org}/teams/{team_slug}/members" | |
| headers = get_headers() | |
| members = set() | |
| page = 1 | |
| while True: | |
| resp = requests.get(f"{url}?per_page=100&page={page}", headers=headers) | |
| if resp.status_code != 200: | |
| print(f"Error fetching team members: {resp.status_code} {resp.text}") | |
| sys.exit(1) | |
| data = resp.json() | |
| if not data: | |
| break | |
| members.update([m['login'] for m in data]) | |
| if len(data) < 100: | |
| break | |
| page += 1 | |
| return members | |
| def get_user_email(username): | |
| """Get user's email from GitHub, prioritizing @nvidia.com emails. | |
| Checks in order: | |
| 1. Public profile email | |
| 2. Recent commits in the repository | |
| """ | |
| if username in _email_cache: | |
| return _email_cache[username] | |
| headers = get_headers() | |
| public_email = None | |
| try: | |
| # 1. Try to get user's public profile email first | |
| resp = requests.get(f"{GITHUB_API_URL}/users/{username}", headers=headers) | |
| if resp.status_code == 200: | |
| user_data = resp.json() | |
| email = user_data.get('email') | |
| if email and not email.endswith("@users.noreply.github.com"): | |
| if email.endswith("@nvidia.com"): | |
| _email_cache[username] = email | |
| return email | |
| # Store non-nvidia email as fallback | |
| public_email = email | |
| # 2. Check recent commits in the repository for @nvidia.com email | |
| repo_env = os.environ.get("GITHUB_REPOSITORY", "NVIDIA/Megatron-LM") | |
| commits_url = f"{GITHUB_API_URL}/repos/{repo_env}/commits?author={username}&per_page=10" | |
| resp = requests.get(commits_url, headers=headers) | |
| if resp.status_code == 200: | |
| commits = resp.json() | |
| for commit in commits: | |
| # Get email from commit author | |
| commit_data = commit.get('commit', {}) | |
| author_data = commit_data.get('author', {}) | |
| email = author_data.get('email') | |
| if email and not email.endswith("@users.noreply.github.com"): | |
| if email.endswith("@nvidia.com"): | |
| _email_cache[username] = email | |
| print(f"Found @nvidia.com email for {username} from commits: {email}") | |
| return email | |
| elif public_email is None: | |
| public_email = email | |
| # 3. Use public email if found, otherwise fallback | |
| if public_email: | |
| _email_cache[username] = public_email | |
| print(f"Using public email for {username}: {public_email}") | |
| return public_email | |
| # Fallback to noreply email | |
| fallback = f"{username}@users.noreply.github.com" | |
| _email_cache[username] = fallback | |
| print(f"Warning: No email found for {username}, using fallback: {fallback}") | |
| return fallback | |
| except Exception as e: | |
| print(f"Warning: Could not get email for {username}: {e}") | |
| fallback = f"{username}@users.noreply.github.com" | |
| _email_cache[username] = fallback | |
| return fallback | |
| def get_slack_client(): | |
| """Get Slack WebClient if token is available.""" | |
| slack_token = os.environ.get("SLACK_TOKEN") | |
| if not slack_token: | |
| return None | |
| return WebClient(token=slack_token) | |
| def get_slack_user_id(slack_client, email): | |
| """Get Slack user ID from email.""" | |
| if not slack_client: | |
| return None | |
| if email in _slack_id_cache: | |
| return _slack_id_cache[email] | |
| try: | |
| response = slack_client.users_lookupByEmail(email=email) | |
| user_id = response["user"]["id"] | |
| _slack_id_cache[email] = user_id | |
| return user_id | |
| except SlackApiError as e: | |
| print(f"Warning: Could not find Slack user for {email}: {e.response['error']}") | |
| _slack_id_cache[email] = None | |
| return None | |
| def get_slack_usergroup_id(slack_client, handle): | |
| """Get Slack usergroup ID from handle.""" | |
| if not slack_client: | |
| return None | |
| try: | |
| response = slack_client.usergroups_list(include_users=True) | |
| for usergroup in response.get("usergroups", []): | |
| if usergroup.get("handle") == handle: | |
| return usergroup.get("id"), usergroup.get("users", []) | |
| print(f"Warning: Slack usergroup '{handle}' not found") | |
| return None, [] | |
| except SlackApiError as e: | |
| print(f"Warning: Could not list Slack usergroups: {e.response['error']}") | |
| return None, [] | |
| def update_slack_usergroup(new_oncall_username, old_members_usernames): | |
| """ | |
| Updates the Slack usergroup to contain only the new oncall user. | |
| Adds new oncall first, then removes old members (usergroups need at least one member). | |
| """ | |
| slack_client = get_slack_client() | |
| if not slack_client: | |
| print("Slack token not configured, skipping Slack usergroup update") | |
| return | |
| # Get the new oncall's email and Slack user ID | |
| new_email = get_user_email(new_oncall_username) | |
| new_slack_id = get_slack_user_id(slack_client, new_email) | |
| if not new_slack_id: | |
| print(f"Could not find Slack user ID for {new_oncall_username} ({new_email}), skipping Slack update") | |
| return | |
| # Get the usergroup ID and current members | |
| usergroup_id, current_slack_members = get_slack_usergroup_id(slack_client, SLACK_USERGROUP_HANDLE) | |
| if not usergroup_id: | |
| print(f"Could not find Slack usergroup '{SLACK_USERGROUP_HANDLE}', skipping Slack update") | |
| return | |
| try: | |
| # Step 1: Add new oncall first (include current members to avoid removing anyone yet) | |
| # This ensures usergroup always has at least one member | |
| if new_slack_id not in current_slack_members: | |
| updated_members = list(set(current_slack_members + [new_slack_id])) | |
| slack_client.usergroups_users_update( | |
| usergroup=usergroup_id, | |
| users=updated_members | |
| ) | |
| print(f"Added {new_oncall_username} to Slack usergroup '{SLACK_USERGROUP_HANDLE}'") | |
| # Step 2: Now set the usergroup to contain only the new oncall | |
| slack_client.usergroups_users_update( | |
| usergroup=usergroup_id, | |
| users=[new_slack_id] | |
| ) | |
| print(f"Updated Slack usergroup '{SLACK_USERGROUP_HANDLE}' to contain only {new_oncall_username}") | |
| except SlackApiError as e: | |
| print(f"Failed to update Slack usergroup: {e.response['error']}") | |
| def load_schedule(): | |
| if not os.path.exists(SCHEDULE_FILE): | |
| return [] | |
| try: | |
| with open(SCHEDULE_FILE, 'r') as f: | |
| data = json.load(f) | |
| # Normalize to list of dicts if it's a list of strings | |
| schedule = [] | |
| for item in data: | |
| if isinstance(item, str): | |
| schedule.append({"user": item, "date": "YYYY-MM-DD"}) | |
| else: | |
| schedule.append(item) | |
| return schedule | |
| except (json.JSONDecodeError, FileNotFoundError): | |
| return [] | |
| def save_schedule(schedule): | |
| with open(SCHEDULE_FILE, 'w') as f: | |
| json.dump(schedule, f, indent=4) | |
| f.write('\n') # trailing newline | |
| def update_active_oncall_team(org, new_oncall): | |
| """Updates the active oncall team to contain only the new oncall user.""" | |
| # 1. Get current members of the active team | |
| current_members = get_team_members(org, ACTIVE_ONCALL_TEAM_SLUG) | |
| # 2. Add the new oncall if not present | |
| if new_oncall not in current_members: | |
| url = f"{GITHUB_API_URL}/orgs/{org}/teams/{ACTIVE_ONCALL_TEAM_SLUG}/memberships/{new_oncall}" | |
| resp = requests.put(url, headers=get_headers()) | |
| if resp.status_code == 200: | |
| print(f"Added {new_oncall} to {ACTIVE_ONCALL_TEAM_SLUG}") | |
| else: | |
| print(f"Failed to add {new_oncall} to {ACTIVE_ONCALL_TEAM_SLUG}: {resp.status_code} {resp.text}") | |
| # 3. Remove everyone else | |
| old_members = [] | |
| for member in current_members: | |
| if member not in [new_oncall, 'svcnvidia-nemo-ci']: | |
| old_members.append(member) | |
| url = f"{GITHUB_API_URL}/orgs/{org}/teams/{ACTIVE_ONCALL_TEAM_SLUG}/memberships/{member}" | |
| resp = requests.delete(url, headers=get_headers()) | |
| if resp.status_code == 204: | |
| print(f"Removed {member} from {ACTIVE_ONCALL_TEAM_SLUG}") | |
| else: | |
| print(f"Failed to remove {member} from {ACTIVE_ONCALL_TEAM_SLUG}: {resp.status_code} {resp.text}") | |
| # 4. Update Slack usergroup (add new oncall first, then remove old members) | |
| update_slack_usergroup(new_oncall, old_members) | |
| def rotate_schedule(repo_owner, dry_run=False): | |
| schedule = load_schedule() | |
| print(f"Current schedule length: {len(schedule)}") | |
| # 1. Rotate (Remove past week) | |
| # Only if schedule is not empty. | |
| if schedule: | |
| # Check date of first entry | |
| first_entry = schedule[0] | |
| try: | |
| # We assume the date is the *start* of the oncall shift (Wednesday). | |
| # The shift ends 7 days later. | |
| start_date = datetime.strptime(first_entry['date'], "%Y-%m-%d").date() | |
| end_date = start_date + timedelta(days=7) | |
| today = datetime.now(timezone.utc).date() | |
| # If today is >= end_date, the shift is over. | |
| # (e.g. Started last Wed, ends today Wed. If today is Wed, we rotate) | |
| if today >= end_date: | |
| removed = schedule.pop(0) | |
| print(f"Rotated out: {removed} (Ended {end_date})") | |
| else: | |
| print(f"First entry {first_entry} has not ended yet (Ends {end_date}). Not removing.") | |
| except ValueError: | |
| # Fallback if date is invalid, rotate anyway | |
| removed = schedule.pop(0) | |
| print(f"Rotated out (invalid date): {removed}") | |
| else: | |
| print("Schedule empty, nothing to rotate.") | |
| # 2. Replenish | |
| ensure_schedule_filled(schedule, repo_owner) | |
| # 3. Update active oncall team | |
| if schedule: | |
| current_oncall = schedule[0]['user'] | |
| print(f"New active oncall: {current_oncall}") | |
| if not dry_run: | |
| update_active_oncall_team(repo_owner, current_oncall) | |
| else: | |
| print(f"Dry run: Would update {ACTIVE_ONCALL_TEAM_SLUG} to contain only {current_oncall}") | |
| if not dry_run: | |
| save_schedule(schedule) | |
| print("Schedule updated and saved.") | |
| else: | |
| print("Dry run: Schedule not saved.") | |
| print(json.dumps(schedule, indent=4)) | |
| def get_last_wednesday(): | |
| today = datetime.now(timezone.utc).date() | |
| # Monday=0, Wednesday=2 | |
| offset = (today.weekday() - 2) % 7 | |
| return today - timedelta(days=offset) | |
| def ensure_schedule_filled(schedule, repo_owner): | |
| """Appends users to schedule until it reaches TARGET_WEEKS.""" | |
| members = get_team_members(repo_owner, ROTATION_TEAM_SLUG) | |
| if not members: | |
| print(f"Warning: No team members found in {ROTATION_TEAM_SLUG}.") | |
| return | |
| if 'svcnvidia-nemo-ci' in members: | |
| members.remove('svcnvidia-nemo-ci') | |
| members = list(members) | |
| members.sort() # Deterministic order | |
| while len(schedule) < TARGET_WEEKS: | |
| # Determine start date for the new entry | |
| if not schedule: | |
| # Start with the most recent Wednesday if list is empty | |
| next_date = get_last_wednesday() | |
| # Start with the first member alphabetically if list is empty | |
| next_user = members[0] | |
| else: | |
| last_entry = schedule[-1] | |
| last_user = last_entry['user'] | |
| # Parse last date and add 7 days | |
| try: | |
| last_date = datetime.strptime(last_entry['date'], "%Y-%m-%d").date() | |
| next_date = last_date + timedelta(days=7) | |
| except ValueError: | |
| # Fallback if date is invalid/placeholder | |
| next_date = get_last_wednesday() + timedelta(days=7 * len(schedule)) | |
| try: | |
| # Find index of last scheduled user in the team list | |
| if last_user in members: | |
| last_idx = members.index(last_user) | |
| next_idx = (last_idx + 1) % len(members) | |
| next_user = members[next_idx] | |
| else: | |
| # Last user not in team, just pick first member | |
| next_user = members[0] | |
| except ValueError: | |
| next_user = members[0] | |
| new_entry = {"user": next_user, "date": next_date.strftime("%Y-%m-%d")} | |
| schedule.append(new_entry) | |
| print(f"Appended: {new_entry}") | |
| def assign_reviewer(pr_number): | |
| """Assigns the mcore-oncall team as the reviewer for the PR.""" | |
| owner, repo = get_repo_info() | |
| url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/pulls/{pr_number}/requested_reviewers" | |
| # Assign the oncall team as reviewer | |
| data = {"team_reviewers": [ACTIVE_ONCALL_TEAM_SLUG]} | |
| resp = requests.post(url, headers=get_headers(), json=data) | |
| if resp.status_code in [201, 200]: | |
| print(f"Successfully requested review from team NVIDIA/{ACTIVE_ONCALL_TEAM_SLUG}") | |
| else: | |
| print(f"Failed to request review: {resp.status_code} {resp.text}") | |
| sys.exit(1) | |
| def main(): | |
| parser = argparse.ArgumentParser(description="Manage Oncall Schedule") | |
| subparsers = parser.add_subparsers(dest="command", required=True) | |
| # Rotate command | |
| parser_rotate = subparsers.add_parser("rotate", help="Rotate the schedule (remove first, append new)") | |
| parser_rotate.add_argument("--dry-run", action="store_true", help="Do not save changes") | |
| # Fill command (just fill up to 12 without rotating - useful for init) | |
| parser_fill = subparsers.add_parser("fill", help="Fill the schedule to 12 weeks without rotating") | |
| # Assign command | |
| parser_assign = subparsers.add_parser("assign", help="Assign current oncall to PR") | |
| parser_assign.add_argument("--pr", type=int, required=True, help="PR number") | |
| args = parser.parse_args() | |
| owner, _ = get_repo_info() | |
| if args.command == "rotate": | |
| rotate_schedule(owner, dry_run=args.dry_run) | |
| elif args.command == "fill": | |
| schedule = load_schedule() | |
| ensure_schedule_filled(schedule, owner) | |
| save_schedule(schedule) | |
| print("Schedule filled and saved.") | |
| elif args.command == "assign": | |
| assign_reviewer(args.pr) | |
| if __name__ == "__main__": | |
| main() | |