| | |
| |
|
| | import ast |
| | import sys |
| | import os |
| | import re |
| | import subprocess |
| | import tempfile |
| | from typing import List, Optional, Pattern |
| |
|
# Any release branch reference, e.g. "release_4" or "release_4_docs".
RELEASE_PATTERN = re.compile(r"release_[0-9]+(_docs)*")
# Matches the ways docs invoke pip ("pip", "pip3", "python -m pip",
# "python3 -m pip") to install "mlagents" or "mlagents_envs" (group "package"),
# with an optional "-q " flag (group "quiet") and an optional pinned version
# such as "==1.2.3" or "==1.2.3.dev0".
# The "python3" prefix must be part of the match; otherwise only the
# "pip install ..." tail matches and a rewrite leaves "python3 -m python -m pip".
PIP_INSTALL_PATTERN = re.compile(
    r"(python3? -m )?pip3* install (?P<quiet>-q )?(?P<package>mlagents(_envs)?)(==[0-9]+\.[0-9]+\.[0-9]+(\.dev[0-9]+)?)?"
)
TRAINER_INIT_FILE = "ml-agents/mlagents/trainers/__init__.py"

# Matches any text (including newlines); use it as an ALLOW_LIST value to
# exempt every line of a file from validation.
MATCH_ANY = re.compile(r"(?s).*")

# Filename -> pattern; lines of that file matching the pattern are exempt
# from both the release-link and the pip-install checks.
ALLOW_LIST = {
    "docs/Python-PettingZoo-API.md": re.compile(
        r"\*\*(Verified Package ([0-9]\.?)*|Release [0-9]+)\*\*"
    ),
    "docs/Versioning.md": MATCH_ANY,
    "com.unity.ml-agents/CHANGELOG.md": MATCH_ANY,
    "utils/make_readme_table.py": MATCH_ANY,
    "utils/validate_release_links.py": MATCH_ANY,
}
| |
|
| |
|
def test_release_pattern():
    """Sanity-check RELEASE_PATTERN against URLs that should / should not match."""
    cases = (
        (
            "https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/Food.md",
            True,
        ),
        ("https://github.com/Unity-Technologies/ml-agents/blob/release_4/Foo.md", True),
        (
            "git clone --branch release_4 https://github.com/Unity-Technologies/ml-agents.git",
            True,
        ),
        (
            "https://github.com/Unity-Technologies/ml-agents/blob/release_123_docs/Foo.md",
            True,
        ),
        (
            "https://github.com/Unity-Technologies/ml-agents/blob/release_123/Foo.md",
            True,
        ),
        (
            "https://github.com/Unity-Technologies/ml-agents/blob/latest_release/docs/Foo.md",
            False,
        ),
    )
    for text, should_match in cases:
        found = RELEASE_PATTERN.search(text) is not None
        assert found is should_match

    print("release tests OK!")
| |
|
| |
|
def test_pip_pattern():
    """Sanity-check PIP_INSTALL_PATTERN matching and substitution."""
    commands_that_match = (
        "pip install mlagents",
        "pip3 install -q mlagents",
        "python -m pip install mlagents",
        "python -m pip install mlagents==1.2.3",
        "python -m pip install mlagents_envs==1.2.3",
        "python -m pip install mlagents==11.222.3333",
        "python -m pip install mlagents_envs==11.222.3333",
    )
    for command in commands_that_match:
        assert PIP_INSTALL_PATTERN.search(command) is not None

    # The whole command, including the pinned version, must be consumed.
    replaced = PIP_INSTALL_PATTERN.sub(
        "rm -rf /", "Try running python -m pip install mlagents==1.2.3 to install"
    )
    assert replaced == "Try running rm -rf / to install"

    print("pip tests OK!")
| |
|
| |
|
def update_pip_install_line(line, package_verion):
    """
    Rewrite every mlagents / mlagents_envs pip install command in ``line``.

    Each command matched by ``PIP_INSTALL_PATTERN`` becomes
    ``python -m pip install [-q ]<package>==<package_verion>``, keeping the
    package name and ``-q`` flag of that particular command. Lines with no
    match are returned unchanged.

    :param line: A line of text that may contain pip install commands.
    :param package_verion: Version to pin (misspelled name kept for callers).
    :return: The updated line.
    """

    def _pinned_command(match):
        # Build the replacement from *this* match, so a line that installs
        # both mlagents and mlagents_envs keeps both package names. A callable
        # replacement also inserts the text literally (no backslash escapes).
        quiet_option = match.group("quiet") or ""
        package_name = match.group("package")
        return f"python -m pip install {quiet_option}{package_name}=={package_verion}"

    return PIP_INSTALL_PATTERN.sub(_pinned_command, line)
| |
|
| |
|
def git_ls_files() -> List[str]:
    """
    Run "git ls-files" and return a list with one entry per tracked file.

    ``-z`` makes git emit NUL-terminated paths verbatim. Without it, git
    quotes paths containing special or non-ASCII characters (per
    core.quotePath). Paths are decoded as UTF-8 rather than with the locale
    encoding, and the empty entry after the final terminator is dropped.
    """
    output = subprocess.check_output(["git", "ls-files", "-z"], encoding="utf-8")
    return [path for path in output.split("\0") if path]
| |
|
| |
|
def get_release_tag(init_file: Optional[str] = None) -> Optional[str]:
    """
    Returns the release tag for the mlagents python package.

    The value comes from the first ``__release_tag__ = ...`` assignment in the
    trainers ``__init__.py`` and is evaluated with ``ast.literal_eval``. It
    will be None on the main branch.

    :param init_file: File to read; defaults to ``TRAINER_INIT_FILE``.
    :return: The release tag string, or None if it is unset.
    :raises RuntimeError: If no ``__release_tag__`` assignment is found.
    """
    path = TRAINER_INIT_FILE if init_file is None else init_file
    with open(path, encoding="utf-8") as f:
        for line in f:
            lhs, equals_string, rhs = line.partition("=")
            # Compare the whole assignment target (ignoring an optional type
            # annotation), so comments or other text that merely mention
            # "__release_tag__" are never passed to literal_eval.
            target = lhs.split(":", 1)[0].strip()
            if equals_string and target == "__release_tag__":
                return ast.literal_eval(rhs.strip())
    # None is a legitimate tag value (main branch), so signal "not found"
    # with an exception instead.
    raise RuntimeError("Can't determine release tag")
| |
|
| |
|
def get_python_package_version(init_file: Optional[str] = None) -> str:
    """
    Returns the mlagents python package version.

    The value comes from the first ``__version__ = ...`` assignment in the
    trainers ``__init__.py`` and is evaluated with ``ast.literal_eval``.

    :param init_file: File to read; defaults to ``TRAINER_INIT_FILE``.
    :return: The package version string.
    :raises RuntimeError: If no ``__version__`` assignment is found.
    """
    path = TRAINER_INIT_FILE if init_file is None else init_file
    with open(path, encoding="utf-8") as f:
        for line in f:
            lhs, equals_string, rhs = line.partition("=")
            # Compare the whole assignment target (ignoring an optional type
            # annotation), so comments or other text that merely mention
            # "__version__" are never passed to literal_eval.
            target = lhs.split(":", 1)[0].strip()
            if equals_string and target == "__version__":
                return ast.literal_eval(rhs.strip())
    raise RuntimeError("Can't determine python package version")
| |
|
| |
|
def check_file(
    filename: str,
    release_tag_pattern: Pattern,
    release_tag: str,
    pip_allow_pattern: Pattern,
    package_version: str,
) -> List[str]:
    """
    Validate a single file and return any offending lines.

    A line is offending if either of these holds and the line is not
    exempted by this file's ALLOW_LIST pattern:

    - it matches RELEASE_PATTERN but not ``release_tag_pattern``;
    - it matches PIP_INSTALL_PATTERN but not ``pip_allow_pattern``.

    When any offending line is found, the file is rewritten in place. In each
    offending line, ``release_<N>`` references are replaced by ``release_tag``
    and pip install commands are re-pinned to ``package_version``.

    :param filename: Path of the file to check.
    :param release_tag_pattern: Pattern matching references to the current release.
    :param release_tag: Current release tag, substituted for stale references.
    :param pip_allow_pattern: Pattern matching an acceptable pip install command.
    :param package_version: Version substituted into stale pip install commands.
    :return: One ``"<filename>: <line>"`` entry per offending line.
    """
    bad_lines = []
    output_lines = []
    allow_list_pattern = ALLOW_LIST.get(filename)

    # newline="" keeps each line's original ending (e.g. CRLF) so a rewrite
    # does not silently convert the whole file. Explicit UTF-8 avoids
    # depending on the platform's locale encoding.
    with open(filename, encoding="utf-8", newline="") as f:
        for line in f:
            # Mentions some release branch, e.g. release_123 or release_123_docs.
            has_release_pattern = RELEASE_PATTERN.search(line) is not None
            # Mentions the current release specifically.
            has_release_tag_pattern = release_tag_pattern.search(line) is not None
            # Exempted by this file's allow-list entry, if it has one.
            has_allow_list_pattern = (
                allow_list_pattern is not None
                and allow_list_pattern.search(line) is not None
            )

            pip_install_ok = (
                has_allow_list_pattern
                or PIP_INSTALL_PATTERN.search(line) is None
                or pip_allow_pattern.search(line) is not None
            )
            release_tag_ok = (
                not has_release_pattern
                or has_release_tag_pattern
                or has_allow_list_pattern
            )

            if release_tag_ok and pip_install_ok:
                output_lines.append(line)
                continue

            bad_lines.append(f"{filename}: {line}")
            # A callable replacement inserts the tag literally instead of
            # parsing it as a regex replacement template.
            new_line = re.sub(r"release_[0-9]+", lambda _match: release_tag, line)
            output_lines.append(update_pip_install_line(new_line, package_version))

    if bad_lines:
        # Overwrite in place rather than renaming a temp file into place.
        # os.rename fails across filesystems (the temp dir may live on a
        # different device than the repo), and remove+rename drops the
        # file's permissions.
        with open(filename, "w", encoding="utf-8", newline="") as f:
            f.writelines(output_lines)

    return bad_lines
| |
|
| |
|
def check_all_files(
    release_allow_pattern: Pattern,
    release_tag: str,
    pip_allow_pattern: Pattern,
    package_version: str,
) -> List[str]:
    """
    Validate all files tracked by git and collect the offending lines.

    Only files ending in .py, .md, .cs or .ipynb are checked. Any path
    containing "localized" is skipped.

    :param release_allow_pattern: Pattern matching references to the current release.
    :param release_tag: Current release tag, used when fixing stale references.
    :param pip_allow_pattern: Pattern matching an acceptable pip install command.
    :param package_version: Package version, used when fixing pip install commands.
    :return: Offending lines from every checked file, in git ls-files order.
    """
    checked_extensions = {".py", ".md", ".cs", ".ipynb"}
    candidates = [
        path
        for path in git_ls_files()
        if "localized" not in path
        and os.path.splitext(path)[1] in checked_extensions
    ]

    offending = []
    for path in candidates:
        offending.extend(
            check_file(
                path,
                release_allow_pattern,
                release_tag,
                pip_allow_pattern,
                package_version,
            )
        )
    return offending
| |
|
| |
|
def main():
    """
    Validate (and fix) release links and pip install commands in the repo.

    Exits with status 0 when there is no release tag or when every checked
    file is clean. Otherwise it prints each offending line (the files have
    already been rewritten by check_file) and exits with status 1.
    """
    release_tag = get_release_tag()
    if not release_tag:
        print("Release tag is None, exiting")
        sys.exit(0)

    package_version = get_python_package_version()
    print(f"Release tag: {release_tag}")
    print(f"Python package version: {package_version}")

    # (?![0-9]) stops e.g. "release_2" from also accepting "release_20".
    release_allow_pattern = re.compile(
        rf"{re.escape(release_tag)}(?![0-9])(_docs)?"
    )
    # re.escape makes the version's dots literal. (?!\.?\w) stops e.g.
    # "1.2.3" from also accepting "1.2.30" or "1.2.3.dev0", while still
    # allowing trailing punctuation such as "1.2.3." or "1.2.3`".
    pip_allow_pattern = re.compile(
        r"python -m pip install (-q )?mlagents(_envs)?=="
        + re.escape(package_version)
        + r"(?!\.?\w)"
    )
    bad_lines = check_all_files(
        release_allow_pattern, release_tag, pip_allow_pattern, package_version
    )
    if bad_lines:
        for line in bad_lines:
            print(line)

        print("*************************************************************")
        print(
            "This script attempted to fix the above errors. Please double "
            + "check them to make sure the replacements were done correctly"
        )

    sys.exit(1 if bad_lines else 0)
| |
|
| |
|
if __name__ == "__main__":
    # "--test" runs the regex sanity checks first; the repository
    # validation in main() runs in either case.
    if "--test" in sys.argv:
        for self_check in (test_release_pattern, test_pip_pattern):
            self_check()
    main()
| |
|