Add files using upload-large-folder tool
- .gitattributes +12 -0
- pythonProject/.venv/Scripts/isympy.exe +3 -0
- pythonProject/.venv/Scripts/numpy-config.exe +3 -0
- pythonProject/.venv/Scripts/pip.exe +3 -0
- pythonProject/.venv/Scripts/pip3.10.exe +3 -0
- pythonProject/.venv/Scripts/pip3.exe +3 -0
- pythonProject/.venv/Scripts/python.exe +3 -0
- pythonProject/.venv/Scripts/pythonw.exe +3 -0
- pythonProject/.venv/Scripts/tiny-agents.exe +3 -0
- pythonProject/.venv/Scripts/torchfrtrace.exe +3 -0
- pythonProject/.venv/Scripts/torchrun.exe +3 -0
- pythonProject/.venv/Scripts/tqdm.exe +3 -0
- pythonProject/diffusers-main/docs/source/en/imgs/access_request.png +3 -0
- pythonProject/diffusers-main/utils/notify_benchmarking_status.py +56 -0
- pythonProject/diffusers-main/utils/notify_community_pipelines_mirror.py +54 -0
- pythonProject/diffusers-main/utils/notify_slack_about_release.py +81 -0
- pythonProject/diffusers-main/utils/overwrite_expected_slice.py +90 -0
- pythonProject/diffusers-main/utils/print_env.py +73 -0
- pythonProject/diffusers-main/utils/release.py +162 -0
- pythonProject/diffusers-main/utils/stale.py +69 -0
- pythonProject/diffusers-main/utils/tests_fetcher.py +1128 -0
.gitattributes
CHANGED
@@ -266,3 +266,15 @@ pythonProject/.venv/Scripts/diffusers-cli.exe filter=lfs diff=lfs merge=lfs -text
 pythonProject/.venv/Scripts/check-node.exe filter=lfs diff=lfs merge=lfs -text
 pythonProject/.venv/Scripts/hf.exe filter=lfs diff=lfs merge=lfs -text
 pythonProject/.venv/Scripts/huggingface-cli.exe filter=lfs diff=lfs merge=lfs -text
+pythonProject/.venv/Scripts/numpy-config.exe filter=lfs diff=lfs merge=lfs -text
+pythonProject/.venv/Scripts/isympy.exe filter=lfs diff=lfs merge=lfs -text
+pythonProject/.venv/Scripts/pip3.10.exe filter=lfs diff=lfs merge=lfs -text
+pythonProject/.venv/Scripts/pip.exe filter=lfs diff=lfs merge=lfs -text
+pythonProject/.venv/Scripts/pip3.exe filter=lfs diff=lfs merge=lfs -text
+pythonProject/.venv/Scripts/python.exe filter=lfs diff=lfs merge=lfs -text
+pythonProject/.venv/Scripts/tiny-agents.exe filter=lfs diff=lfs merge=lfs -text
+pythonProject/.venv/Scripts/pythonw.exe filter=lfs diff=lfs merge=lfs -text
+pythonProject/.venv/Scripts/torchfrtrace.exe filter=lfs diff=lfs merge=lfs -text
+pythonProject/.venv/Scripts/torchrun.exe filter=lfs diff=lfs merge=lfs -text
+pythonProject/.venv/Scripts/tqdm.exe filter=lfs diff=lfs merge=lfs -text
+pythonProject/diffusers-main/docs/source/en/imgs/access_request.png filter=lfs diff=lfs merge=lfs -text
pythonProject/.venv/Scripts/isympy.exe
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e0939da22ca6f0052870e8dc3da42a78679a46b194e96ed827be31c9a834471
+size 108395
pythonProject/.venv/Scripts/numpy-config.exe
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6c0c18062b265bbf63daa978ed69ff5cbb63cda8dcaae7787b563d61d0cd29e
+size 108406
pythonProject/.venv/Scripts/pip.exe
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df129c9009bd7ba09b466cbafbd196813830eef8c03f3009c1f30946bb377acb
+size 108411
pythonProject/.venv/Scripts/pip3.10.exe
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df129c9009bd7ba09b466cbafbd196813830eef8c03f3009c1f30946bb377acb
+size 108411
pythonProject/.venv/Scripts/pip3.exe
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df129c9009bd7ba09b466cbafbd196813830eef8c03f3009c1f30946bb377acb
+size 108411
pythonProject/.venv/Scripts/python.exe
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2c836c52cdf063180b9ee76f67ac42946101b79ac457f3494035a67c090d961
+size 268568
pythonProject/.venv/Scripts/pythonw.exe
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1c37c6c3c14074b69b3b8543ed5edcf6666b876c8f73672e52b58f6bbc5dd3ee
+size 257304
pythonProject/.venv/Scripts/tiny-agents.exe
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cfbbcb52bbb876ef485061f31d62ec773a51726d305409a0147aab6b4eaf2a9f
+size 108421
pythonProject/.venv/Scripts/torchfrtrace.exe
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b3d9da8a20090cb5417548b5098b6dba389755916c493e460a782fdb791c3e2
+size 108419
pythonProject/.venv/Scripts/torchrun.exe
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f4d393a13ca96560946983d8f7ed998c83726bc94584ddfd92942bea7de12565
+size 108410
pythonProject/.venv/Scripts/tqdm.exe
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e8b93dfc2cd0722449177b0dd541c8b5ce17d3f3153f6a0ee7c0778ce419cd3b
+size 108397
pythonProject/diffusers-main/docs/source/en/imgs/access_request.png
ADDED
Git LFS Details
pythonProject/diffusers-main/utils/notify_benchmarking_status.py
ADDED
@@ -0,0 +1,56 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import requests
+
+
+# Configuration
+GITHUB_REPO = "huggingface/diffusers"
+GITHUB_RUN_ID = os.getenv("GITHUB_RUN_ID")
+SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL")
+
+
+def main(args):
+    action_url = f"https://github.com/{GITHUB_REPO}/actions/runs/{GITHUB_RUN_ID}"
+    if args.status == "success":
+        hub_path = "https://huggingface.co/datasets/diffusers/benchmarks/blob/main/collated_results.csv"
+        message = (
+            "✅ New benchmark workflow successfully run.\n"
+            f"🕸️ GitHub Action URL: {action_url}.\n"
+            f"🤗 Check out the benchmarks here: {hub_path}."
+        )
+    else:
+        message = (
+            "❌ Something wrong happened in the benchmarking workflow.\n"
+            f"Check out the GitHub Action to know more: {action_url}."
+        )
+
+    payload = {"text": message}
+    response = requests.post(SLACK_WEBHOOK_URL, json=payload)
+
+    if response.status_code == 200:
+        print("Notification sent to Slack successfully.")
+    else:
+        print("Failed to send notification to Slack.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--status", type=str, default="success", choices=["success", "failure"])
+    args = parser.parse_args()
+    main(args)
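A minimal sketch of how this script might be invoked from a CI step. The `--status` flag and the environment variables come from the code above; the surrounding workflow wiring and the webhook value are assumptions.

```bash
# Assumed CI step: GITHUB_RUN_ID is set by GitHub Actions, SLACK_WEBHOOK_URL by the workflow secrets.
export SLACK_WEBHOOK_URL="https://hooks.slack.com/services/XXX"  # hypothetical webhook URL
python utils/notify_benchmarking_status.py --status success
# On a failed benchmark job:
python utils/notify_benchmarking_status.py --status failure
```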
pythonProject/diffusers-main/utils/notify_community_pipelines_mirror.py
ADDED
@@ -0,0 +1,54 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import requests
+
+
+# Configuration
+GITHUB_REPO = "huggingface/diffusers"
+GITHUB_RUN_ID = os.getenv("GITHUB_RUN_ID")
+SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL")
+PATH_IN_REPO = os.getenv("PATH_IN_REPO")
+
+
+def main(args):
+    action_url = f"https://github.com/{GITHUB_REPO}/actions/runs/{GITHUB_RUN_ID}"
+    if args.status == "success":
+        hub_path = f"https://huggingface.co/datasets/diffusers/community-pipelines-mirror/tree/main/{PATH_IN_REPO}"
+        message = (
+            "✅ Community pipelines successfully mirrored.\n"
+            f"🕸️ GitHub Action URL: {action_url}.\n"
+            f"🤗 Hub location: {hub_path}."
+        )
+    else:
+        message = f"❌ Something wrong happened. Check out the GitHub Action to know more: {action_url}."
+
+    payload = {"text": message}
+    response = requests.post(SLACK_WEBHOOK_URL, json=payload)
+
+    if response.status_code == 200:
+        print("Notification sent to Slack successfully.")
+    else:
+        print("Failed to send notification to Slack.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--status", type=str, default="success", choices=["success", "failure"])
+    args = parser.parse_args()
+    main(args)
pythonProject/diffusers-main/utils/notify_slack_about_release.py
ADDED
@@ -0,0 +1,81 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import requests
+
+
+# Configuration
+LIBRARY_NAME = "diffusers"
+GITHUB_REPO = "huggingface/diffusers"
+SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL")
+
+
+def check_pypi_for_latest_release(library_name):
+    """Check PyPI for the latest release of the library."""
+    response = requests.get(f"https://pypi.org/pypi/{library_name}/json", timeout=60)
+    if response.status_code == 200:
+        data = response.json()
+        return data["info"]["version"]
+    else:
+        print("Failed to fetch library details from PyPI.")
+        return None
+
+
+def get_github_release_info(github_repo):
+    """Fetch the latest release info from GitHub."""
+    url = f"https://api.github.com/repos/{github_repo}/releases/latest"
+    response = requests.get(url, timeout=60)
+
+    if response.status_code == 200:
+        data = response.json()
+        return {"tag_name": data["tag_name"], "url": data["html_url"], "release_time": data["published_at"]}
+
+    else:
+        print("Failed to fetch release info from GitHub.")
+        return None
+
+
+def notify_slack(webhook_url, library_name, version, release_info):
+    """Send a notification to a Slack channel."""
+    message = (
+        f"🚀 New release for {library_name} available: version **{version}** 🎉\n"
+        f"📜 Release Notes: {release_info['url']}\n"
+        f"⏱️ Release time: {release_info['release_time']}"
+    )
+    payload = {"text": message}
+    response = requests.post(webhook_url, json=payload)
+
+    if response.status_code == 200:
+        print("Notification sent to Slack successfully.")
+    else:
+        print("Failed to send notification to Slack.")
+
+
+def main():
+    latest_version = check_pypi_for_latest_release(LIBRARY_NAME)
+    release_info = get_github_release_info(GITHUB_REPO)
+    parsed_version = release_info["tag_name"].replace("v", "")
+
+    if latest_version and release_info and latest_version == parsed_version:
+        notify_slack(SLACK_WEBHOOK_URL, LIBRARY_NAME, latest_version, release_info)
+    else:
+        print(f"{latest_version=}, {release_info=}, {parsed_version=}")
+        raise ValueError("There were some problems.")
+
+
+if __name__ == "__main__":
+    main()
pythonProject/diffusers-main/utils/overwrite_expected_slice.py
ADDED
@@ -0,0 +1,90 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+from collections import defaultdict
+
+
+def overwrite_file(file, class_name, test_name, correct_line, done_test):
+    _id = f"{file}_{class_name}_{test_name}"
+    done_test[_id] += 1
+
+    with open(file, "r") as f:
+        lines = f.readlines()
+
+    class_regex = f"class {class_name}("
+    test_regex = f"{4 * ' '}def {test_name}("
+    line_begin_regex = f"{8 * ' '}{correct_line.split()[0]}"
+    another_line_begin_regex = f"{16 * ' '}{correct_line.split()[0]}"
+    in_class = False
+    in_func = False
+    in_line = False
+    insert_line = False
+    count = 0
+    spaces = 0
+
+    new_lines = []
+    for line in lines:
+        if line.startswith(class_regex):
+            in_class = True
+        elif in_class and line.startswith(test_regex):
+            in_func = True
+        elif in_class and in_func and (line.startswith(line_begin_regex) or line.startswith(another_line_begin_regex)):
+            spaces = len(line.split(correct_line.split()[0])[0])
+            count += 1
+
+            if count == done_test[_id]:
+                in_line = True
+
+        if in_class and in_func and in_line:
+            if ")" not in line:
+                continue
+            else:
+                insert_line = True
+
+        if in_class and in_func and in_line and insert_line:
+            new_lines.append(f"{spaces * ' '}{correct_line}")
+            in_class = in_func = in_line = insert_line = False
+        else:
+            new_lines.append(line)
+
+    with open(file, "w") as f:
+        for line in new_lines:
+            f.write(line)
+
+
+def main(correct, fail=None):
+    if fail is not None:
+        with open(fail, "r") as f:
+            test_failures = {l.strip() for l in f.readlines()}
+    else:
+        test_failures = None
+
+    with open(correct, "r") as f:
+        correct_lines = f.readlines()
+
+    done_tests = defaultdict(int)
+    for line in correct_lines:
+        file, class_name, test_name, correct_line = line.split("::")
+        if test_failures is None or "::".join([file, class_name, test_name]) in test_failures:
+            overwrite_file(file, class_name, test_name, correct_line, done_tests)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--correct_filename", help="filename of tests with expected result")
+    parser.add_argument("--fail_filename", help="filename of test failures", type=str, default=None)
+    args = parser.parse_args()
+
+    main(args.correct_filename, args.fail_filename)
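A hedged usage sketch, based on the argument parser and on the `::`-separated format that `main()` splits on; the file names and the sample line below are hypothetical:

```bash
# Each line of the corrections file is: <test file>::<class name>::<test name>::<corrected line>
echo 'tests/test_example.py::ExampleTests::test_slice::expected_slice = np.array([0.1, 0.2, 0.3])' > corrected.txt
python utils/overwrite_expected_slice.py --correct_filename corrected.txt --fail_filename failures.txt
```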
pythonProject/diffusers-main/utils/print_env.py
ADDED
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# this script dumps information about the environment
+
+import os
+import platform
+import sys
+
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+print("Python version:", sys.version)
+
+print("OS platform:", platform.platform())
+print("OS architecture:", platform.machine())
+try:
+    import psutil
+
+    vm = psutil.virtual_memory()
+    total_gb = vm.total / (1024**3)
+    available_gb = vm.available / (1024**3)
+    print(f"Total RAM: {total_gb:.2f} GB")
+    print(f"Available RAM: {available_gb:.2f} GB")
+except ImportError:
+    pass
+
+try:
+    import torch
+
+    print("Torch version:", torch.__version__)
+    print("Cuda available:", torch.cuda.is_available())
+    if torch.cuda.is_available():
+        print("Cuda version:", torch.version.cuda)
+        print("CuDNN version:", torch.backends.cudnn.version())
+        print("Number of GPUs available:", torch.cuda.device_count())
+        device_properties = torch.cuda.get_device_properties(0)
+        total_memory = device_properties.total_memory / (1024**3)
+        print(f"CUDA memory: {total_memory} GB")
+
+    print("XPU available:", hasattr(torch, "xpu") and torch.xpu.is_available())
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        print("XPU model:", torch.xpu.get_device_properties(0).name)
+        print("XPU compiler version:", torch.version.xpu)
+        print("Number of XPUs available:", torch.xpu.device_count())
+        device_properties = torch.xpu.get_device_properties(0)
+        total_memory = device_properties.total_memory / (1024**3)
+        print(f"XPU memory: {total_memory} GB")
+
+
+except ImportError:
+    print("Torch version:", None)
+
+try:
+    import transformers
+
+    print("transformers version:", transformers.__version__)
+except ImportError:
+    print("transformers version:", None)
pythonProject/diffusers-main/utils/release.py
ADDED
@@ -0,0 +1,162 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import re
+
+import packaging.version
+
+
+PATH_TO_EXAMPLES = "examples/"
+REPLACE_PATTERNS = {
+    "examples": (re.compile(r'^check_min_version\("[^"]+"\)\s*$', re.MULTILINE), 'check_min_version("VERSION")\n'),
+    "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'),
+    "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'),
+    "doc": (re.compile(r'^(\s*)release\s*=\s*"[^"]+"$', re.MULTILINE), 'release = "VERSION"\n'),
+}
+REPLACE_FILES = {
+    "init": "src/diffusers/__init__.py",
+    "setup": "setup.py",
+}
+README_FILE = "README.md"
+
+
+def update_version_in_file(fname, version, pattern):
+    """Update the version in one file using a specific pattern."""
+    with open(fname, "r", encoding="utf-8", newline="\n") as f:
+        code = f.read()
+    re_pattern, replace = REPLACE_PATTERNS[pattern]
+    replace = replace.replace("VERSION", version)
+    code = re_pattern.sub(replace, code)
+    with open(fname, "w", encoding="utf-8", newline="\n") as f:
+        f.write(code)
+
+
+def update_version_in_examples(version):
+    """Update the version in all examples files."""
+    for folder, directories, fnames in os.walk(PATH_TO_EXAMPLES):
+        # Removing some of the folders with non-actively maintained examples from the walk
+        if "research_projects" in directories:
+            directories.remove("research_projects")
+        if "legacy" in directories:
+            directories.remove("legacy")
+        for fname in fnames:
+            if fname.endswith(".py"):
+                update_version_in_file(os.path.join(folder, fname), version, pattern="examples")
+
+
+def global_version_update(version, patch=False):
+    """Update the version in all needed files."""
+    for pattern, fname in REPLACE_FILES.items():
+        update_version_in_file(fname, version, pattern)
+    if not patch:
+        update_version_in_examples(version)
+
+
+def clean_main_ref_in_model_list():
+    """Replace the links from main doc tp stable doc in the model list of the README."""
+    # If the introduction or the conclusion of the list change, the prompts may need to be updated.
+    _start_prompt = "🤗 Transformers currently provides the following architectures"
+    _end_prompt = "1. Want to contribute a new model?"
+    with open(README_FILE, "r", encoding="utf-8", newline="\n") as f:
+        lines = f.readlines()
+
+    # Find the start of the list.
+    start_index = 0
+    while not lines[start_index].startswith(_start_prompt):
+        start_index += 1
+    start_index += 1
+
+    index = start_index
+    # Update the lines in the model list.
+    while not lines[index].startswith(_end_prompt):
+        if lines[index].startswith("1."):
+            lines[index] = lines[index].replace(
+                "https://huggingface.co/docs/diffusers/main/model_doc",
+                "https://huggingface.co/docs/diffusers/model_doc",
+            )
+        index += 1
+
+    with open(README_FILE, "w", encoding="utf-8", newline="\n") as f:
+        f.writelines(lines)
+
+
+def get_version():
+    """Reads the current version in the __init__."""
+    with open(REPLACE_FILES["init"], "r") as f:
+        code = f.read()
+    default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0]
+    return packaging.version.parse(default_version)
+
+
+def pre_release_work(patch=False):
+    """Do all the necessary pre-release steps."""
+    # First let's get the default version: base version if we are in dev, bump minor otherwise.
+    default_version = get_version()
+    if patch and default_version.is_devrelease:
+        raise ValueError("Can't create a patch version from the dev branch, checkout a released version!")
+    if default_version.is_devrelease:
+        default_version = default_version.base_version
+    elif patch:
+        default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}"
+    else:
+        default_version = f"{default_version.major}.{default_version.minor + 1}.0"
+
+    # Now let's ask nicely if that's the right one.
+    version = input(f"Which version are you releasing? [{default_version}]")
+    if len(version) == 0:
+        version = default_version
+
+    print(f"Updating version to {version}.")
+    global_version_update(version, patch=patch)
+
+
+# if not patch:
+#     print("Cleaning main README, don't forget to run `make fix-copies`.")
+#     clean_main_ref_in_model_list()
+
+
+def post_release_work():
+    """Do all the necessary post-release steps."""
+    # First let's get the current version
+    current_version = get_version()
+    dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0"
+    current_version = current_version.base_version
+
+    # Check with the user we got that right.
+    version = input(f"Which version are we developing now? [{dev_version}]")
+    if len(version) == 0:
+        version = dev_version
+
+    print(f"Updating version to {version}.")
+    global_version_update(version)
+
+
+# print("Cleaning main README, don't forget to run `make fix-copies`.")
+# clean_main_ref_in_model_list()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.")
+    parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.")
+    args = parser.parse_args()
+    if not args.post_release:
+        pre_release_work(patch=args.patch)
+    elif args.patch:
+        print("Nothing to do after a patch :-)")
+    else:
+        post_release_work()
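Based on the argument parser above, a sketch of the three invocation modes (running from the repository root is an assumption):

```bash
python utils/release.py                 # pre-release: bump the version and update examples
python utils/release.py --patch         # pre-release for a patch: bump only the micro version
python utils/release.py --post_release  # post-release: move the version back to a .dev0
```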
pythonProject/diffusers-main/utils/stale.py
ADDED
@@ -0,0 +1,69 @@
+# Copyright 2025 The HuggingFace Team, the AllenNLP library authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Script to close stale issue. Taken in part from the AllenNLP repository.
+https://github.com/allenai/allennlp.
+"""
+
+import os
+from datetime import datetime as dt
+from datetime import timezone
+
+from github import Github
+
+
+LABELS_TO_EXEMPT = [
+    "close-to-merge",
+    "good first issue",
+    "good second issue",
+    "good difficult issue",
+    "enhancement",
+    "new pipeline/model",
+    "new scheduler",
+    "wip",
+]
+
+
+def main():
+    g = Github(os.environ["GITHUB_TOKEN"])
+    repo = g.get_repo("huggingface/diffusers")
+    open_issues = repo.get_issues(state="open")
+
+    for issue in open_issues:
+        labels = [label.name.lower() for label in issue.get_labels()]
+        if "stale" in labels:
+            comments = sorted(issue.get_comments(), key=lambda i: i.created_at, reverse=True)
+            last_comment = comments[0] if len(comments) > 0 else None
+            if last_comment is not None and last_comment.user.login != "github-actions[bot]":
+                # Opens the issue if someone other than Stalebot commented.
+                issue.edit(state="open")
+                issue.remove_from_labels("stale")
+        elif (
+            (dt.now(timezone.utc) - issue.updated_at).days > 23
+            and (dt.now(timezone.utc) - issue.created_at).days >= 30
+            and not any(label in LABELS_TO_EXEMPT for label in labels)
+        ):
+            # Post a Stalebot notification after 23 days of inactivity.
+            issue.create_comment(
+                "This issue has been automatically marked as stale because it has not had "
+                "recent activity. If you think this still needs to be addressed "
+                "please comment on this thread.\n\nPlease note that issues that do not follow the "
+                "[contributing guidelines](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md) "
+                "are likely to be ignored."
+            )
+            issue.add_to_labels("stale")
+
+
+if __name__ == "__main__":
+    main()
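A sketch of running the stale-issue script locally; the token value is hypothetical, and needing write access to issues is an assumption based on the operations the script performs:

```bash
export GITHUB_TOKEN="ghp_xxx"  # hypothetical personal access token
pip install PyGithub           # provides the `github` module imported above
python utils/stale.py
```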
pythonProject/diffusers-main/utils/tests_fetcher.py
ADDED
|
@@ -0,0 +1,1128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
Welcome to tests_fetcher V2.
|
| 18 |
+
|
| 19 |
+
This util is designed to fetch tests to run on a PR so that only the tests impacted by the modifications are run, and
|
| 20 |
+
when too many models are being impacted, only run the tests of a subset of core models. It works like this.
|
| 21 |
+
|
| 22 |
+
Stage 1: Identify the modified files. For jobs that run on the main branch, it's just the diff with the last commit.
|
| 23 |
+
On a PR, this takes all the files from the branching point to the current commit (so all modifications in a PR, not
|
| 24 |
+
just the last commit) but excludes modifications that are on docstrings or comments only.
|
| 25 |
+
|
| 26 |
+
Stage 2: Extract the tests to run. This is done by looking at the imports in each module and test file: if module A
|
| 27 |
+
imports module B, then changing module B impacts module A, so the tests using module A should be run. We thus get the
|
| 28 |
+
dependencies of each model and then recursively builds the 'reverse' map of dependencies to get all modules and tests
|
| 29 |
+
impacted by a given file. We then only keep the tests (and only the core models tests if there are too many modules).
|
| 30 |
+
|
| 31 |
+
Caveats:
|
| 32 |
+
- This module only filters tests by files (not individual tests) so it's better to have tests for different things
|
| 33 |
+
in different files.
|
| 34 |
+
- This module assumes inits are just importing things, not really building objects, so it's better to structure
|
| 35 |
+
them this way and move objects building in separate submodules.
|
| 36 |
+
|
| 37 |
+
Usage:
|
| 38 |
+
|
| 39 |
+
Base use to fetch the tests in a pull request
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
python utils/tests_fetcher.py
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
Base use to fetch the tests on a the main branch (with diff from the last commit):
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
python utils/tests_fetcher.py --diff_with_last_commit
|
| 49 |
+
```
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
import argparse
|
| 53 |
+
import collections
|
| 54 |
+
import json
|
| 55 |
+
import os
|
| 56 |
+
import re
|
| 57 |
+
from contextlib import contextmanager
|
| 58 |
+
from pathlib import Path
|
| 59 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 60 |
+
|
| 61 |
+
from git import Repo
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
PATH_TO_REPO = Path(__file__).parent.parent.resolve()
|
| 65 |
+
PATH_TO_EXAMPLES = PATH_TO_REPO / "examples"
|
| 66 |
+
PATH_TO_DIFFUSERS = PATH_TO_REPO / "src/diffusers"
|
| 67 |
+
PATH_TO_TESTS = PATH_TO_REPO / "tests"
|
| 68 |
+
|
| 69 |
+
# Ignore fixtures in tests folder
|
| 70 |
+
# Ignore lora since they are always tested
|
| 71 |
+
MODULES_TO_IGNORE = ["fixtures", "lora"]
|
| 72 |
+
|
| 73 |
+
IMPORTANT_PIPELINES = [
|
| 74 |
+
"controlnet",
|
| 75 |
+
"stable_diffusion",
|
| 76 |
+
"stable_diffusion_2",
|
| 77 |
+
"stable_diffusion_xl",
|
| 78 |
+
"stable_video_diffusion",
|
| 79 |
+
"deepfloyd_if",
|
| 80 |
+
"kandinsky",
|
| 81 |
+
"kandinsky2_2",
|
| 82 |
+
"text_to_video_synthesis",
|
| 83 |
+
"wuerstchen",
|
| 84 |
+
]
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@contextmanager
|
| 88 |
+
def checkout_commit(repo: Repo, commit_id: str):
|
| 89 |
+
"""
|
| 90 |
+
Context manager that checks out a given commit when entered, but gets back to the reference it was at on exit.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
repo (`git.Repo`): A git repository (for instance the Transformers repo).
|
| 94 |
+
commit_id (`str`): The commit reference to checkout inside the context manager.
|
| 95 |
+
"""
|
| 96 |
+
current_head = repo.head.commit if repo.head.is_detached else repo.head.ref
|
| 97 |
+
|
| 98 |
+
try:
|
| 99 |
+
repo.git.checkout(commit_id)
|
| 100 |
+
yield
|
| 101 |
+
|
| 102 |
+
finally:
|
| 103 |
+
repo.git.checkout(current_head)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def clean_code(content: str) -> str:
|
| 107 |
+
"""
|
| 108 |
+
Remove docstrings, empty line or comments from some code (used to detect if a diff is real or only concern
|
| 109 |
+
comments or docstrings).
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
content (`str`): The code to clean
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
`str`: The cleaned code.
|
| 116 |
+
"""
|
| 117 |
+
# We need to deactivate autoformatting here to write escaped triple quotes (we cannot use real triple quotes or
|
| 118 |
+
# this would mess up the result if this function applied to this particular file).
|
| 119 |
+
# fmt: off
|
| 120 |
+
# Remove docstrings by splitting on triple " then triple ':
|
| 121 |
+
splits = content.split('\"\"\"')
|
| 122 |
+
content = "".join(splits[::2])
|
| 123 |
+
splits = content.split("\'\'\'")
|
| 124 |
+
# fmt: on
|
| 125 |
+
content = "".join(splits[::2])
|
| 126 |
+
|
| 127 |
+
# Remove empty lines and comments
|
| 128 |
+
lines_to_keep = []
|
| 129 |
+
for line in content.split("\n"):
|
| 130 |
+
# remove anything that is after a # sign.
|
| 131 |
+
line = re.sub("#.*$", "", line)
|
| 132 |
+
# remove white lines
|
| 133 |
+
if len(line) != 0 and not line.isspace():
|
| 134 |
+
lines_to_keep.append(line)
|
| 135 |
+
return "\n".join(lines_to_keep)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def keep_doc_examples_only(content: str) -> str:
|
| 139 |
+
"""
|
| 140 |
+
Remove everything from the code content except the doc examples (used to determined if a diff should trigger doc
|
| 141 |
+
tests or not).
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
content (`str`): The code to clean
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
`str`: The cleaned code.
|
| 148 |
+
"""
|
| 149 |
+
# Keep doc examples only by splitting on triple "`"
|
| 150 |
+
splits = content.split("```")
|
| 151 |
+
# Add leading and trailing "```" so the navigation is easier when compared to the original input `content`
|
| 152 |
+
content = "```" + "```".join(splits[1::2]) + "```"
|
| 153 |
+
|
| 154 |
+
# Remove empty lines and comments
|
| 155 |
+
lines_to_keep = []
|
| 156 |
+
for line in content.split("\n"):
|
| 157 |
+
# remove anything that is after a # sign.
|
| 158 |
+
line = re.sub("#.*$", "", line)
|
| 159 |
+
# remove white lines
|
| 160 |
+
if len(line) != 0 and not line.isspace():
|
| 161 |
+
lines_to_keep.append(line)
|
| 162 |
+
return "\n".join(lines_to_keep)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def get_all_tests() -> List[str]:
|
| 166 |
+
"""
|
| 167 |
+
Walks the `tests` folder to return a list of files/subfolders. This is used to split the tests to run when using
|
| 168 |
+
parallelism. The split is:
|
| 169 |
+
|
| 170 |
+
- folders under `tests`: (`tokenization`, `pipelines`, etc) except the subfolder `models` is excluded.
|
| 171 |
+
- folders under `tests/models`: `bert`, `gpt2`, etc.
|
| 172 |
+
- test files under `tests`: `test_modeling_common.py`, `test_tokenization_common.py`, etc.
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
# test folders/files directly under `tests` folder
|
| 176 |
+
tests = os.listdir(PATH_TO_TESTS)
|
| 177 |
+
tests = [f"tests/{f}" for f in tests if "__pycache__" not in f]
|
| 178 |
+
tests = sorted([f for f in tests if (PATH_TO_REPO / f).is_dir() or f.startswith("tests/test_")])
|
| 179 |
+
|
| 180 |
+
return tests
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def diff_is_docstring_only(repo: Repo, branching_point: str, filename: str) -> bool:
|
| 184 |
+
"""
|
| 185 |
+
Check if the diff is only in docstrings (or comments and whitespace) in a filename.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
repo (`git.Repo`): A git repository (for instance the Transformers repo).
|
| 189 |
+
branching_point (`str`): The commit reference of where to compare for the diff.
|
| 190 |
+
filename (`str`): The filename where we want to know if the diff isonly in docstrings/comments.
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
`bool`: Whether the diff is docstring/comments only or not.
|
| 194 |
+
"""
|
| 195 |
+
folder = Path(repo.working_dir)
|
| 196 |
+
with checkout_commit(repo, branching_point):
|
| 197 |
+
with open(folder / filename, "r", encoding="utf-8") as f:
|
| 198 |
+
old_content = f.read()
|
| 199 |
+
|
| 200 |
+
with open(folder / filename, "r", encoding="utf-8") as f:
|
| 201 |
+
new_content = f.read()
|
| 202 |
+
|
| 203 |
+
old_content_clean = clean_code(old_content)
|
| 204 |
+
new_content_clean = clean_code(new_content)
|
| 205 |
+
|
| 206 |
+
return old_content_clean == new_content_clean
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def diff_contains_doc_examples(repo: Repo, branching_point: str, filename: str) -> bool:
|
| 210 |
+
"""
|
| 211 |
+
Check if the diff is only in code examples of the doc in a filename.
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
repo (`git.Repo`): A git repository (for instance the Transformers repo).
|
| 215 |
+
branching_point (`str`): The commit reference of where to compare for the diff.
|
| 216 |
+
filename (`str`): The filename where we want to know if the diff is only in codes examples.
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
`bool`: Whether the diff is only in code examples of the doc or not.
|
| 220 |
+
"""
|
| 221 |
+
folder = Path(repo.working_dir)
|
| 222 |
+
with checkout_commit(repo, branching_point):
|
| 223 |
+
with open(folder / filename, "r", encoding="utf-8") as f:
|
| 224 |
+
old_content = f.read()
|
| 225 |
+
|
| 226 |
+
with open(folder / filename, "r", encoding="utf-8") as f:
|
| 227 |
+
new_content = f.read()
|
| 228 |
+
|
| 229 |
+
old_content_clean = keep_doc_examples_only(old_content)
|
| 230 |
+
new_content_clean = keep_doc_examples_only(new_content)
|
| 231 |
+
|
| 232 |
+
return old_content_clean != new_content_clean
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def get_diff(repo: Repo, base_commit: str, commits: List[str]) -> List[str]:
|
| 236 |
+
"""
|
| 237 |
+
Get the diff between a base commit and one or several commits.
|
| 238 |
+
|
| 239 |
+
Args:
|
| 240 |
+
repo (`git.Repo`):
|
| 241 |
+
A git repository (for instance the Transformers repo).
|
| 242 |
+
base_commit (`str`):
|
| 243 |
+
The commit reference of where to compare for the diff. This is the current commit, not the branching point!
|
| 244 |
+
commits (`List[str]`):
|
| 245 |
+
The list of commits with which to compare the repo at `base_commit` (so the branching point).
|
| 246 |
+
|
| 247 |
+
Returns:
|
| 248 |
+
`List[str]`: The list of Python files with a diff (files added, renamed or deleted are always returned, files
|
| 249 |
+
modified are returned if the diff in the file is not only in docstrings or comments, see
|
| 250 |
+
`diff_is_docstring_only`).
|
| 251 |
+
"""
|
| 252 |
+
print("\n### DIFF ###\n")
|
| 253 |
+
code_diff = []
|
| 254 |
+
for commit in commits:
|
| 255 |
+
for diff_obj in commit.diff(base_commit):
|
| 256 |
+
# We always add new python files
|
| 257 |
+
if diff_obj.change_type == "A" and diff_obj.b_path.endswith(".py"):
|
| 258 |
+
code_diff.append(diff_obj.b_path)
|
| 259 |
+
# We check that deleted python files won't break corresponding tests.
|
| 260 |
+
elif diff_obj.change_type == "D" and diff_obj.a_path.endswith(".py"):
|
| 261 |
+
code_diff.append(diff_obj.a_path)
|
| 262 |
+
# Now for modified files
|
| 263 |
+
elif diff_obj.change_type in ["M", "R"] and diff_obj.b_path.endswith(".py"):
|
| 264 |
+
# In case of renames, we'll look at the tests using both the old and new name.
|
| 265 |
+
if diff_obj.a_path != diff_obj.b_path:
|
| 266 |
+
code_diff.extend([diff_obj.a_path, diff_obj.b_path])
|
| 267 |
+
else:
|
| 268 |
+
# Otherwise, we check modifications are in code and not docstrings.
|
| 269 |
+
if diff_is_docstring_only(repo, commit, diff_obj.b_path):
|
| 270 |
+
print(f"Ignoring diff in {diff_obj.b_path} as it only concerns docstrings or comments.")
|
| 271 |
+
else:
|
| 272 |
+
code_diff.append(diff_obj.a_path)
|
| 273 |
+
|
| 274 |
+
return code_diff
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def get_modified_python_files(diff_with_last_commit: bool = False) -> List[str]:
|
| 278 |
+
"""
|
| 279 |
+
Return a list of python files that have been modified between:
|
| 280 |
+
|
| 281 |
+
- the current head and the main branch if `diff_with_last_commit=False` (default)
|
| 282 |
+
- the current head and its parent commit otherwise.
|
| 283 |
+
|
| 284 |
+
Returns:
|
| 285 |
+
`List[str]`: The list of Python files with a diff (files added, renamed or deleted are always returned, files
|
| 286 |
+
modified are returned if the diff in the file is not only in docstrings or comments, see
|
| 287 |
+
`diff_is_docstring_only`).
|
| 288 |
+
"""
|
| 289 |
+
repo = Repo(PATH_TO_REPO)
|
| 290 |
+
|
| 291 |
+
if not diff_with_last_commit:
|
| 292 |
+
# Need to fetch refs for main using remotes when running with github actions.
|
| 293 |
+
upstream_main = repo.remotes.origin.refs.main
|
| 294 |
+
|
| 295 |
+
print(f"main is at {upstream_main.commit}")
|
| 296 |
+
print(f"Current head is at {repo.head.commit}")
|
| 297 |
+
|
| 298 |
+
branching_commits = repo.merge_base(upstream_main, repo.head)
|
| 299 |
+
for commit in branching_commits:
|
| 300 |
+
print(f"Branching commit: {commit}")
|
| 301 |
+
return get_diff(repo, repo.head.commit, branching_commits)
|
| 302 |
+
else:
|
| 303 |
+
print(f"main is at {repo.head.commit}")
|
| 304 |
+
parent_commits = repo.head.commit.parents
|
| 305 |
+
for commit in parent_commits:
|
| 306 |
+
print(f"Parent commit: {commit}")
|
| 307 |
+
return get_diff(repo, repo.head.commit, parent_commits)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def get_diff_for_doctesting(repo: Repo, base_commit: str, commits: List[str]) -> List[str]:
|
| 311 |
+
"""
|
| 312 |
+
Get the diff in doc examples between a base commit and one or several commits.
|
| 313 |
+
|
| 314 |
+
Args:
|
| 315 |
+
repo (`git.Repo`):
|
| 316 |
+
A git repository (for instance the Transformers repo).
|
| 317 |
+
base_commit (`str`):
|
| 318 |
+
The commit reference of where to compare for the diff. This is the current commit, not the branching point!
|
| 319 |
+
commits (`List[str]`):
|
| 320 |
+
The list of commits with which to compare the repo at `base_commit` (so the branching point).
|
| 321 |
+
|
| 322 |
+
Returns:
|
| 323 |
+
`List[str]`: The list of Python and Markdown files with a diff (files added or renamed are always returned, files
|
| 324 |
+
modified are returned if the diff in the file contains some doc example(s)).
|
| 325 |
+
"""
|
| 326 |
+
print("\n### DIFF ###\n")
|
| 327 |
+
code_diff = []
|
| 328 |
+
for commit in commits:
|
| 329 |
+
for diff_obj in commit.diff(base_commit):
|
| 330 |
+
# We only consider Python files and doc files.
|
| 331 |
+
if not diff_obj.b_path.endswith(".py") and not diff_obj.b_path.endswith(".md"):
|
| 332 |
+
continue
|
| 333 |
+
# We always add new python/md files
|
| 334 |
+
if diff_obj.change_type in ["A"]:
|
| 335 |
+
code_diff.append(diff_obj.b_path)
|
| 336 |
+
# Now for modified files
|
| 337 |
+
elif diff_obj.change_type in ["M", "R"]:
|
| 338 |
+
# In case of renames, we'll look at the tests using both the old and new name.
|
| 339 |
+
if diff_obj.a_path != diff_obj.b_path:
|
| 340 |
+
code_diff.extend([diff_obj.a_path, diff_obj.b_path])
|
| 341 |
+
else:
|
| 342 |
+
# Otherwise, we check modifications contain some doc example(s).
|
| 343 |
+
if diff_contains_doc_examples(repo, commit, diff_obj.b_path):
|
| 344 |
+
code_diff.append(diff_obj.a_path)
|
| 345 |
+
else:
|
| 346 |
+
print(f"Ignoring diff in {diff_obj.b_path} as it doesn't contain any doc example.")
|
| 347 |
+
|
| 348 |
+
return code_diff
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def get_all_doctest_files() -> List[str]:
|
| 352 |
+
"""
|
| 353 |
+
Return the complete list of python and Markdown files on which we run doctest.
|
| 354 |
+
|
| 355 |
+
At this moment, we restrict this to only take files from `src/` or `docs/source/en/` that are not in `utils/not_doctested.txt`.
|
| 356 |
+
|
| 357 |
+
Returns:
|
| 358 |
+
`List[str]`: The complete list of Python and Markdown files on which we run doctest.
|
| 359 |
+
"""
|
| 360 |
+
py_files = [str(x.relative_to(PATH_TO_REPO)) for x in PATH_TO_REPO.glob("**/*.py")]
|
| 361 |
+
md_files = [str(x.relative_to(PATH_TO_REPO)) for x in PATH_TO_REPO.glob("**/*.md")]
|
| 362 |
+
test_files_to_run = py_files + md_files
|
| 363 |
+
|
| 364 |
+
# only include files in `src` or `docs/source/en/`
|
| 365 |
+
test_files_to_run = [x for x in test_files_to_run if x.startswith(("src/", "docs/source/en/"))]
|
| 366 |
+
# exclude __init__.py files
|
| 367 |
+
test_files_to_run = [x for x in test_files_to_run if not x.endswith(("__init__.py",))]
|
| 368 |
+
|
| 369 |
+
# These are files not doctested yet.
|
| 370 |
+
with open("utils/not_doctested.txt") as fp:
|
| 371 |
+
not_doctested = {x.split(" ")[0] for x in fp.read().strip().split("\n")}
|
| 372 |
+
|
| 373 |
+
# So far we don't have 100% coverage for doctest. This line will be removed once we achieve 100%.
|
| 374 |
+
test_files_to_run = [x for x in test_files_to_run if x not in not_doctested]
|
| 375 |
+
|
| 376 |
+
return sorted(test_files_to_run)
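# Sketch of the bookkeeping above (the line format of `utils/not_doctested.txt`
# is an assumption inferred from the parsing: a file path first, optionally
# followed by extra columns). Anything listed there is dropped from the result of
#
#     get_all_doctest_files()
#     # e.g. ["docs/source/en/index.md", "src/diffusers/schedulers/scheduling_ddim.py", ...]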
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def get_new_doctest_files(repo, base_commit, branching_commit) -> List[str]:
|
| 380 |
+
"""
|
| 381 |
+
Get the list of files that were removed from "utils/not_doctested.txt", between `base_commit` and
|
| 382 |
+
`branching_commit`.
|
| 383 |
+
|
| 384 |
+
Returns:
|
| 385 |
+
`List[str]`: List of files that were removed from "utils/not_doctested.txt".
|
| 386 |
+
"""
|
| 387 |
+
for diff_obj in branching_commit.diff(base_commit):
|
| 388 |
+
# Ignores all but the "utils/not_doctested.txt" file.
|
| 389 |
+
if diff_obj.a_path != "utils/not_doctested.txt":
|
| 390 |
+
continue
|
| 391 |
+
# Loads the two versions
|
| 392 |
+
folder = Path(repo.working_dir)
|
| 393 |
+
with checkout_commit(repo, branching_commit):
|
| 394 |
+
with open(folder / "utils/not_doctested.txt", "r", encoding="utf-8") as f:
|
| 395 |
+
old_content = f.read()
|
| 396 |
+
with open(folder / "utils/not_doctested.txt", "r", encoding="utf-8") as f:
|
| 397 |
+
new_content = f.read()
|
| 398 |
+
# Compute the removed lines and return them
|
| 399 |
+
removed_content = {x.split(" ")[0] for x in old_content.split("\n")} - {
|
| 400 |
+
x.split(" ")[0] for x in new_content.split("\n")
|
| 401 |
+
}
|
| 402 |
+
return sorted(removed_content)
|
| 403 |
+
return []
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]:
|
| 407 |
+
"""
|
| 408 |
+
Return a list of Python and Markdown files where doc examples have been modified between:
|
| 409 |
+
|
| 410 |
+
- the current head and the main branch if `diff_with_last_commit=False` (default)
|
| 411 |
+
- the current head and its parent commit otherwise.
|
| 412 |
+
|
| 413 |
+
Returns:
|
| 414 |
+
`List[str]`: The list of Python and Markdown files with a diff (files added or renamed are always returned, files
|
| 415 |
+
modified are returned if the diff in the file contains some doc example(s)).
|
| 416 |
+
"""
|
| 417 |
+
repo = Repo(PATH_TO_REPO)
|
| 418 |
+
|
| 419 |
+
test_files_to_run = [] # noqa
|
| 420 |
+
if not diff_with_last_commit:
|
| 421 |
+
upstream_main = repo.remotes.origin.refs.main
|
| 422 |
+
print(f"main is at {upstream_main.commit}")
|
| 423 |
+
print(f"Current head is at {repo.head.commit}")
|
| 424 |
+
|
| 425 |
+
branching_commits = repo.merge_base(upstream_main, repo.head)
|
| 426 |
+
for commit in branching_commits:
|
| 427 |
+
print(f"Branching commit: {commit}")
|
| 428 |
+
test_files_to_run = get_diff_for_doctesting(repo, repo.head.commit, branching_commits)
|
| 429 |
+
else:
|
| 430 |
+
print(f"main is at {repo.head.commit}")
|
| 431 |
+
parent_commits = repo.head.commit.parents
|
| 432 |
+
for commit in parent_commits:
|
| 433 |
+
print(f"Parent commit: {commit}")
|
| 434 |
+
test_files_to_run = get_diff_for_doctesting(repo, repo.head.commit, parent_commits)
|
| 435 |
+
|
| 436 |
+
all_test_files_to_run = get_all_doctest_files()
|
| 437 |
+
|
| 438 |
+
# Add to the test files to run any removed entry from "utils/not_doctested.txt".
|
| 439 |
+
new_test_files = get_new_doctest_files(repo, repo.head.commit, upstream_main.commit)
|
| 440 |
+
test_files_to_run = list(set(test_files_to_run + new_test_files))
|
| 441 |
+
|
| 442 |
+
# Do not run slow doctest tests on CircleCI
|
| 443 |
+
with open("utils/slow_documentation_tests.txt") as fp:
|
| 444 |
+
slow_documentation_tests = set(fp.read().strip().split("\n"))
|
| 445 |
+
test_files_to_run = [
|
| 446 |
+
x for x in test_files_to_run if x in all_test_files_to_run and x not in slow_documentation_tests
|
| 447 |
+
]
|
| 448 |
+
|
| 449 |
+
# Make sure we did not end up with a test file that was removed
|
| 450 |
+
test_files_to_run = [f for f in test_files_to_run if (PATH_TO_REPO / f).exists()]
|
| 451 |
+
|
| 452 |
+
return sorted(test_files_to_run)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
# (?:^|\n) -> Non-capturing group for the beginning of the doc or a new line.
|
| 456 |
+
# \s*from\s+(\.+\S+)\s+import\s+([^\n]+) -> Line only contains from .xxx import yyy and we catch .xxx and yyy
|
| 457 |
+
# (?=\n) -> Look-ahead to a new line. We can't just put \n here or using find_all on this re will only catch every
|
| 458 |
+
# other import.
|
| 459 |
+
_re_single_line_relative_imports = re.compile(r"(?:^|\n)\s*from\s+(\.+\S+)\s+import\s+([^\n]+)(?=\n)")
|
| 460 |
+
# (?:^|\n) -> Non-capturing group for the beginning of the doc or a new line.
|
| 461 |
+
# \s*from\s+(\.+\S+)\s+import\s+\(([^\)]+)\) -> Line continues with from .xxx import (yyy) and we catch .xxx and yyy
|
| 462 |
+
# yyy will take multiple lines otherwise there wouldn't be parenthesis.
|
| 463 |
+
_re_multi_line_relative_imports = re.compile(r"(?:^|\n)\s*from\s+(\.+\S+)\s+import\s+\(([^\)]+)\)")
|
| 464 |
+
# (?:^|\n) -> Non-capturing group for the beginning of the doc or a new line.
|
| 465 |
+
# \s*from\s+diffusers(\S*)\s+import\s+([^\n]+) -> Line only contains from diffusers.xxx import yyy and we catch
|
| 466 |
+
# .xxx and yyy
|
| 467 |
+
# (?=\n) -> Look-ahead to a new line. We can't just put \n here or using find_all on this re will only catch every
|
| 468 |
+
# other import.
|
| 469 |
+
_re_single_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+diffusers(\S*)\s+import\s+([^\n]+)(?=\n)")
|
| 470 |
+
# (?:^|\n) -> Non-capturing group for the beginning of the doc or a new line.
|
| 471 |
+
# \s*from\s+diffusers(\S*)\s+import\s+\(([^\)]+)\) -> Line continues with from diffusers.xxx import (yyy) and we
|
| 472 |
+
# catch .xxx and yyy. yyy will take multiple lines otherwise there wouldn't be parenthesis.
|
| 473 |
+
_re_multi_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+diffusers(\S*)\s+import\s+\(([^\)]+)\)")
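# Quick illustration of what the patterns above capture (the inputs are made up):
#
#     _re_single_line_relative_imports.findall("\nfrom .utils import logging\n")
#     # -> [(".utils", "logging")]
#     _re_single_line_direct_imports.findall("\nfrom diffusers.models import UNet2DModel\n")
#     # -> [(".models", "UNet2DModel")]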
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def extract_imports(module_fname: str, cache: Dict[str, List[str]] = None) -> List[str]:
|
| 477 |
+
"""
|
| 478 |
+
Get the imports a given module makes.
|
| 479 |
+
|
| 480 |
+
Args:
|
| 481 |
+
module_fname (`str`):
|
| 482 |
+
The name of the file of the module where we want to look at the imports (given relative to the root of
|
| 483 |
+
the repo).
|
| 484 |
+
cache (Dictionary `str` to `List[str]`, *optional*):
|
| 485 |
+
To speed up this function if it was previously called on `module_fname`, the cache of all previously
|
| 486 |
+
computed results.
|
| 487 |
+
|
| 488 |
+
Returns:
|
| 489 |
+
`List[str]`: The list of module filenames imported in the input `module_fname` (a submodule we import from that
|
| 490 |
+
is a subfolder will give its init file).
|
| 491 |
+
"""
|
| 492 |
+
if cache is not None and module_fname in cache:
|
| 493 |
+
return cache[module_fname]
|
| 494 |
+
|
| 495 |
+
with open(PATH_TO_REPO / module_fname, "r", encoding="utf-8") as f:
|
| 496 |
+
content = f.read()
|
| 497 |
+
|
| 498 |
+
# Filter out all docstrings to not get imports in code examples. As before we need to deactivate formatting to
|
| 499 |
+
# keep this as escaped quotes and avoid this function failing on this file.
|
| 500 |
+
# fmt: off
|
| 501 |
+
splits = content.split('\"\"\"')
|
| 502 |
+
# fmt: on
|
| 503 |
+
content = "".join(splits[::2])
|
| 504 |
+
|
| 505 |
+
module_parts = str(module_fname).split(os.path.sep)
|
| 506 |
+
imported_modules = []
|
| 507 |
+
|
| 508 |
+
# Let's start with relative imports
|
| 509 |
+
relative_imports = _re_single_line_relative_imports.findall(content)
|
| 510 |
+
relative_imports = [
|
| 511 |
+
(mod, imp) for mod, imp in relative_imports if "# tests_ignore" not in imp and imp.strip() != "("
|
| 512 |
+
]
|
| 513 |
+
multiline_relative_imports = _re_multi_line_relative_imports.findall(content)
|
| 514 |
+
relative_imports += [(mod, imp) for mod, imp in multiline_relative_imports if "# tests_ignore" not in imp]
|
| 515 |
+
|
| 516 |
+
# We need to remove parts of the module name depending on the depth of the relative imports.
|
| 517 |
+
for module, imports in relative_imports:
|
| 518 |
+
level = 0
|
| 519 |
+
while module.startswith("."):
|
| 520 |
+
module = module[1:]
|
| 521 |
+
level += 1
|
| 522 |
+
|
| 523 |
+
if len(module) > 0:
|
| 524 |
+
dep_parts = module_parts[: len(module_parts) - level] + module.split(".")
|
| 525 |
+
else:
|
| 526 |
+
dep_parts = module_parts[: len(module_parts) - level]
|
| 527 |
+
imported_module = os.path.sep.join(dep_parts)
|
| 528 |
+
imported_modules.append((imported_module, [imp.strip() for imp in imports.split(",")]))
|
| 529 |
+
|
| 530 |
+
# Let's continue with direct imports
|
| 531 |
+
direct_imports = _re_single_line_direct_imports.findall(content)
|
| 532 |
+
direct_imports = [(mod, imp) for mod, imp in direct_imports if "# tests_ignore" not in imp and imp.strip() != "("]
|
| 533 |
+
multiline_direct_imports = _re_multi_line_direct_imports.findall(content)
|
| 534 |
+
direct_imports += [(mod, imp) for mod, imp in multiline_direct_imports if "# tests_ignore" not in imp]
|
| 535 |
+
|
| 536 |
+
# We need to find the relative path of those imports.
|
| 537 |
+
for module, imports in direct_imports:
|
| 538 |
+
import_parts = module.split(".")[1:] # ignore the name of the repo since we add it below.
|
| 539 |
+
dep_parts = ["src", "diffusers"] + import_parts
|
| 540 |
+
imported_module = os.path.sep.join(dep_parts)
|
| 541 |
+
imported_modules.append((imported_module, [imp.strip() for imp in imports.split(",")]))
|
| 542 |
+
|
| 543 |
+
result = []
|
| 544 |
+
# Double check we get proper modules (either a python file or a folder with an init).
|
| 545 |
+
for module_file, imports in imported_modules:
|
| 546 |
+
if (PATH_TO_REPO / f"{module_file}.py").is_file():
|
| 547 |
+
module_file = f"{module_file}.py"
|
| 548 |
+
elif (PATH_TO_REPO / module_file).is_dir() and (PATH_TO_REPO / module_file / "__init__.py").is_file():
|
| 549 |
+
module_file = os.path.sep.join([module_file, "__init__.py"])
|
| 550 |
+
imports = [imp for imp in imports if len(imp) > 0 and re.match("^[A-Za-z0-9_]*$", imp)]
|
| 551 |
+
if len(imports) > 0:
|
| 552 |
+
result.append((module_file, imports))
|
| 553 |
+
|
| 554 |
+
if cache is not None:
|
| 555 |
+
cache[module_fname] = result
|
| 556 |
+
|
| 557 |
+
return result
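# Illustrative result (hypothetical module): for a file such as
# src/diffusers/models/unet_2d.py containing
#     from ..configuration_utils import ConfigMixin
#     from diffusers.utils import logging
# the function above would return something like
#     [("src/diffusers/configuration_utils.py", ["ConfigMixin"]),
#      ("src/diffusers/utils/__init__.py", ["logging"])]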
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def get_module_dependencies(module_fname: str, cache: Dict[str, List[str]] = None) -> List[str]:
|
| 561 |
+
"""
|
| 562 |
+
Refines the result of `extract_imports` to remove subfolders and get a proper list of module filenames: if a file
|
| 563 |
+
has an import `from utils import Foo, Bar`, with `utils` being a subfolder containing many files, this will traverse
|
| 564 |
+
the `utils` init file to check where those dependencies come from: for instance the files utils/foo.py and utils/bar.py.
|
| 565 |
+
|
| 566 |
+
Warning: This presupposes that all intermediate inits are properly built (with imports from the respective
|
| 567 |
+
submodules) and works better if objects are defined in submodules rather than in the intermediate init (otherwise the
|
| 568 |
+
intermediate init is added, and inits usually have a lot of dependencies).
|
| 569 |
+
|
| 570 |
+
Args:
|
| 571 |
+
module_fname (`str`):
|
| 572 |
+
The name of the file of the module where we want to look at the imports (given relative to the root of
|
| 573 |
+
the repo).
|
| 574 |
+
cache (Dictionary `str` to `List[str]`, *optional*):
|
| 575 |
+
To speed up this function if it was previously called on `module_fname`, the cache of all previously
|
| 576 |
+
computed results.
|
| 577 |
+
|
| 578 |
+
Returns:
|
| 579 |
+
`List[str]`: The list of module filenames imported in the input `module_fname` (with submodule imports refined).
|
| 580 |
+
"""
|
| 581 |
+
dependencies = []
|
| 582 |
+
imported_modules = extract_imports(module_fname, cache=cache)
|
| 583 |
+
# The while loop is to recursively traverse all inits we may encounter: we will add things as we go.
|
| 584 |
+
while len(imported_modules) > 0:
|
| 585 |
+
new_modules = []
|
| 586 |
+
for module, imports in imported_modules:
|
| 587 |
+
# If we end up in an __init__ we are often not actually importing from this init (except in the case where
|
| 588 |
+
# the object is fully defined in the __init__)
|
| 589 |
+
if module.endswith("__init__.py"):
|
| 590 |
+
# So we get the imports from that init then try to find where our objects come from.
|
| 591 |
+
new_imported_modules = extract_imports(module, cache=cache)
|
| 592 |
+
for new_module, new_imports in new_imported_modules:
|
| 593 |
+
if any(i in new_imports for i in imports):
|
| 594 |
+
if new_module not in dependencies:
|
| 595 |
+
new_modules.append((new_module, [i for i in new_imports if i in imports]))
|
| 596 |
+
imports = [i for i in imports if i not in new_imports]
|
| 597 |
+
if len(imports) > 0:
|
| 598 |
+
# If there are any objects left, they may actually be submodules
|
| 599 |
+
path_to_module = PATH_TO_REPO / module.replace("__init__.py", "")
|
| 600 |
+
dependencies.extend(
|
| 601 |
+
[
|
| 602 |
+
os.path.join(module.replace("__init__.py", ""), f"{i}.py")
|
| 603 |
+
for i in imports
|
| 604 |
+
if (path_to_module / f"{i}.py").is_file()
|
| 605 |
+
]
|
| 606 |
+
)
|
| 607 |
+
imports = [i for i in imports if not (path_to_module / f"{i}.py").is_file()]
|
| 608 |
+
if len(imports) > 0:
|
| 609 |
+
# Then if there are still objects left, they are fully defined in the init, so we keep it as a
|
| 610 |
+
# dependency.
|
| 611 |
+
dependencies.append(module)
|
| 612 |
+
else:
|
| 613 |
+
dependencies.append(module)
|
| 614 |
+
|
| 615 |
+
imported_modules = new_modules
|
| 616 |
+
|
| 617 |
+
return dependencies
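# Illustrative call (hypothetical path): compared to `extract_imports`, the init
# traversal above resolves re-exports, so
#     get_module_dependencies("src/diffusers/models/unet_2d.py")
# lists the concrete files the imported objects are defined in instead of the
# package __init__ they were imported through.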
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def create_reverse_dependency_tree() -> List[Tuple[str, str]]:
|
| 621 |
+
"""
|
| 622 |
+
Create a list of all edges (a, b), each meaning that modifying a impacts b, with b going over all module and test files and a over their direct dependencies.
|
| 623 |
+
"""
|
| 624 |
+
cache = {}
|
| 625 |
+
all_modules = list(PATH_TO_DIFFUSERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py"))
|
| 626 |
+
all_modules = [str(mod.relative_to(PATH_TO_REPO)) for mod in all_modules]
|
| 627 |
+
edges = [(dep, mod) for mod in all_modules for dep in get_module_dependencies(mod, cache=cache)]
|
| 628 |
+
|
| 629 |
+
return list(set(edges))
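# Illustrative edge (hypothetical paths): each element is a (dependency, dependent)
# pair, e.g.
#     ("src/diffusers/models/attention.py", "src/diffusers/models/unet_2d_condition.py")
# meaning that a change to the first file potentially impacts the second.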
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
def get_tree_starting_at(module: str, edges: List[Tuple[str, str]]) -> List[Union[str, List[str]]]:
|
| 633 |
+
"""
|
| 634 |
+
Returns the tree starting at a given module following all edges.
|
| 635 |
+
|
| 636 |
+
Args:
|
| 637 |
+
module (`str`): The module that will be the root of the subtree we want.
|
| 638 |
+
edges (`List[Tuple[str, str]]`): The list of all edges of the tree.
|
| 639 |
+
|
| 640 |
+
Returns:
|
| 641 |
+
`List[Union[str, List[str]]]`: The tree to print in the following format: [module, [list of edges
|
| 642 |
+
starting at module], [list of edges starting at the preceding level], ...]
|
| 643 |
+
"""
|
| 644 |
+
vertices_seen = [module]
|
| 645 |
+
new_edges = [edge for edge in edges if edge[0] == module and edge[1] != module and "__init__.py" not in edge[1]]
|
| 646 |
+
tree = [module]
|
| 647 |
+
while len(new_edges) > 0:
|
| 648 |
+
tree.append(new_edges)
|
| 649 |
+
final_vertices = list({edge[1] for edge in new_edges})
|
| 650 |
+
vertices_seen.extend(final_vertices)
|
| 651 |
+
new_edges = [
|
| 652 |
+
edge
|
| 653 |
+
for edge in edges
|
| 654 |
+
if edge[0] in final_vertices and edge[1] not in vertices_seen and "__init__.py" not in edge[1]
|
| 655 |
+
]
|
| 656 |
+
|
| 657 |
+
return tree
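# Illustrative shape of the returned tree (abstract module names "a", "b", "c"):
# with edges [("a", "b"), ("b", "c")], get_tree_starting_at("a", edges) gives
#     ["a", [("a", "b")], [("b", "c")]]
# i.e. the root followed by one list of edges per traversal level.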
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
def print_tree_deps_of(module, all_edges=None):
|
| 661 |
+
"""
|
| 662 |
+
Prints the tree of modules depending on a given module.
|
| 663 |
+
|
| 664 |
+
Args:
|
| 665 |
+
module (`str`): The module that will be the root of the subtree we want.
|
| 666 |
+
all_edges (`List[Tuple[str, str]]`, *optional*):
|
| 667 |
+
The list of all edges of the tree. Will be set to `create_reverse_dependency_tree()` if not passed.
|
| 668 |
+
"""
|
| 669 |
+
if all_edges is None:
|
| 670 |
+
all_edges = create_reverse_dependency_tree()
|
| 671 |
+
tree = get_tree_starting_at(module, all_edges)
|
| 672 |
+
|
| 673 |
+
# The list of lines is a list of tuples (line_to_be_printed, module)
|
| 674 |
+
# Keeping the modules lets us know where to insert each new line in the list.
|
| 675 |
+
lines = [(tree[0], tree[0])]
|
| 676 |
+
for index in range(1, len(tree)):
|
| 677 |
+
edges = tree[index]
|
| 678 |
+
start_edges = {edge[0] for edge in edges}
|
| 679 |
+
|
| 680 |
+
for start in start_edges:
|
| 681 |
+
end_edges = {edge[1] for edge in edges if edge[0] == start}
|
| 682 |
+
# We will insert all those edges just after the line showing start.
|
| 683 |
+
pos = 0
|
| 684 |
+
while lines[pos][1] != start:
|
| 685 |
+
pos += 1
|
| 686 |
+
lines = lines[: pos + 1] + [(" " * (2 * index) + end, end) for end in end_edges] + lines[pos + 1 :]
|
| 687 |
+
|
| 688 |
+
for line in lines:
|
| 689 |
+
# We don't print the refs that were just here to help build lines.
|
| 690 |
+
print(line[0])
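# Usage sketch: this helper backs the `--print_dependencies_of` flag at the bottom
# of this script, e.g. (hypothetical path)
#     python utils/tests_fetcher.py --print_dependencies_of src/diffusers/models/attention.py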
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
def init_test_examples_dependencies() -> Tuple[Dict[str, List[str]], List[str]]:
|
| 694 |
+
"""
|
| 695 |
+
The test examples do not import from the examples (which are just scripts, not modules) so we need some extra
|
| 696 |
+
care initializing the dependency map, which is the goal of this function. It initializes the dependency map for
|
| 697 |
+
example files by linking each example to the example test file for the example framework.
|
| 698 |
+
|
| 699 |
+
Returns:
|
| 700 |
+
`Tuple[Dict[str, List[str]], List[str]]`: A tuple with two elements: the initialized dependency map which is a
|
| 701 |
+
dict test example file to list of example files potentially tested by that test file, and the list of all
|
| 702 |
+
example files (to avoid recomputing it later).
|
| 703 |
+
"""
|
| 704 |
+
test_example_deps = {}
|
| 705 |
+
all_examples = []
|
| 706 |
+
for framework in ["flax", "pytorch", "tensorflow"]:
|
| 707 |
+
test_files = list((PATH_TO_EXAMPLES / framework).glob("test_*.py"))
|
| 708 |
+
all_examples.extend(test_files)
|
| 709 |
+
# Remove the files at the root of examples/framework since they are not proper examples (they are either utils
|
| 710 |
+
# or example test files).
|
| 711 |
+
examples = [
|
| 712 |
+
f for f in (PATH_TO_EXAMPLES / framework).glob("**/*.py") if f.parent != PATH_TO_EXAMPLES / framework
|
| 713 |
+
]
|
| 714 |
+
all_examples.extend(examples)
|
| 715 |
+
for test_file in test_files:
|
| 716 |
+
with open(test_file, "r", encoding="utf-8") as f:
|
| 717 |
+
content = f.read()
|
| 718 |
+
# Map all examples to the test files found in examples/framework.
|
| 719 |
+
test_example_deps[str(test_file.relative_to(PATH_TO_REPO))] = [
|
| 720 |
+
str(e.relative_to(PATH_TO_REPO)) for e in examples if e.name in content
|
| 721 |
+
]
|
| 722 |
+
# Also map the test files to themselves.
|
| 723 |
+
test_example_deps[str(test_file.relative_to(PATH_TO_REPO))].append(
|
| 724 |
+
str(test_file.relative_to(PATH_TO_REPO))
|
| 725 |
+
)
|
| 726 |
+
return test_example_deps, all_examples
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
def create_reverse_dependency_map() -> Dict[str, List[str]]:
|
| 730 |
+
"""
|
| 731 |
+
Create the dependency map from module/test filename to the list of modules/tests that depend on it recursively.
|
| 732 |
+
|
| 733 |
+
Returns:
|
| 734 |
+
`Dict[str, List[str]]`: The reverse dependency map as a dictionary mapping filenames to all the filenames
|
| 735 |
+
depending on it recursively. This way the tests impacted by a change in file A are the test files in the list
|
| 736 |
+
corresponding to key A in this result.
|
| 737 |
+
"""
|
| 738 |
+
cache = {}
|
| 739 |
+
# Start from the example deps init.
|
| 740 |
+
example_deps, examples = init_test_examples_dependencies()
|
| 741 |
+
# Add all modules and all tests to all examples
|
| 742 |
+
all_modules = list(PATH_TO_DIFFUSERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py")) + examples
|
| 743 |
+
all_modules = [str(mod.relative_to(PATH_TO_REPO)) for mod in all_modules]
|
| 744 |
+
# Compute the direct dependencies of all modules.
|
| 745 |
+
direct_deps = {m: get_module_dependencies(m, cache=cache) for m in all_modules}
|
| 746 |
+
direct_deps.update(example_deps)
|
| 747 |
+
|
| 748 |
+
# This recurses the dependencies
|
| 749 |
+
something_changed = True
|
| 750 |
+
while something_changed:
|
| 751 |
+
something_changed = False
|
| 752 |
+
for m in all_modules:
|
| 753 |
+
for d in direct_deps[m]:
|
| 754 |
+
# We stop recursing at an init (cause we always end up in the main init and we don't want to add all
|
| 755 |
+
# files which the main init imports)
|
| 756 |
+
if d.endswith("__init__.py"):
|
| 757 |
+
continue
|
| 758 |
+
if d not in direct_deps:
|
| 759 |
+
raise ValueError(f"KeyError:{d}. From {m}")
|
| 760 |
+
new_deps = set(direct_deps[d]) - set(direct_deps[m])
|
| 761 |
+
if len(new_deps) > 0:
|
| 762 |
+
direct_deps[m].extend(list(new_deps))
|
| 763 |
+
something_changed = True
|
| 764 |
+
|
| 765 |
+
# Finally we can build the reverse map.
|
| 766 |
+
reverse_map = collections.defaultdict(list)
|
| 767 |
+
for m in all_modules:
|
| 768 |
+
for d in direct_deps[m]:
|
| 769 |
+
reverse_map[d].append(m)
|
| 770 |
+
|
| 771 |
+
# For inits, we don't do the reverse deps but the direct deps: if modifying an init, we want to make sure we test
|
| 772 |
+
# all the modules impacted by that init.
|
| 773 |
+
for m in [f for f in all_modules if f.endswith("__init__.py")]:
|
| 774 |
+
direct_deps = get_module_dependencies(m, cache=cache)
|
| 775 |
+
deps = sum([reverse_map[d] for d in direct_deps if not d.endswith("__init__.py")], direct_deps)
|
| 776 |
+
reverse_map[m] = list(set(deps) - {m})
|
| 777 |
+
|
| 778 |
+
return reverse_map
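# Illustrative lookup (hypothetical entries): the map is keyed by source files and
# lists everything that transitively depends on them, e.g.
#     reverse_map["src/diffusers/models/attention.py"]
#     # -> ["src/diffusers/models/unet_2d_condition.py", "tests/pipelines/test_pipelines.py", ...]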
|
| 779 |
+
|
| 780 |
+
|
| 781 |
+
def create_module_to_test_map(reverse_map: Dict[str, List[str]] = None) -> Dict[str, List[str]]:
|
| 782 |
+
"""
|
| 783 |
+
Extract the test files from the reverse dependency map.
|
| 784 |
+
|
| 785 |
+
Args:
|
| 786 |
+
reverse_map (`Dict[str, List[str]]`, *optional*):
|
| 787 |
+
The reverse dependency map as created by `create_reverse_dependency_map`. Will default to the result of
|
| 788 |
+
that function if not provided.
|
| 791 |
+
|
| 792 |
+
Returns:
|
| 793 |
+
`Dict[str, List[str]]`: A dictionary that maps each file to the tests to execute if that file was modified.
|
| 794 |
+
"""
|
| 795 |
+
if reverse_map is None:
|
| 796 |
+
reverse_map = create_reverse_dependency_map()
|
| 797 |
+
|
| 798 |
+
# Utility that tells us if a given file is a test (taking test examples into account)
|
| 799 |
+
def is_test(fname):
|
| 800 |
+
if fname.startswith("tests"):
|
| 801 |
+
return True
|
| 802 |
+
if fname.startswith("examples") and fname.split(os.path.sep)[-1].startswith("test"):
|
| 803 |
+
return True
|
| 804 |
+
return False
|
| 805 |
+
|
| 806 |
+
# Build the test map
|
| 807 |
+
test_map = {module: [f for f in deps if is_test(f)] for module, deps in reverse_map.items()}
|
| 808 |
+
|
| 809 |
+
return test_map
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
def check_imports_all_exist():
|
| 813 |
+
"""
|
| 814 |
+
Isn't used per se by the test fetcher but might be used later as a quality check. Putting this here for now so the
|
| 815 |
+
code is not lost. This checks that all imports in a given file actually exist.
|
| 816 |
+
"""
|
| 817 |
+
cache = {}
|
| 818 |
+
all_modules = list(PATH_TO_DIFFUSERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py"))
|
| 819 |
+
all_modules = [str(mod.relative_to(PATH_TO_REPO)) for mod in all_modules]
|
| 820 |
+
direct_deps = {m: get_module_dependencies(m, cache=cache) for m in all_modules}
|
| 821 |
+
|
| 822 |
+
for module, deps in direct_deps.items():
|
| 823 |
+
for dep in deps:
|
| 824 |
+
if not (PATH_TO_REPO / dep).is_file():
|
| 825 |
+
print(f"{module} has dependency on {dep} which does not exist.")
|
| 826 |
+
|
| 827 |
+
|
| 828 |
+
def _print_list(l) -> str:
|
| 829 |
+
"""
|
| 830 |
+
Pretty print a list of elements with one line per element and a - starting each line.
|
| 831 |
+
"""
|
| 832 |
+
return "\n".join([f"- {f}" for f in l])
|
| 833 |
+
|
| 834 |
+
|
| 835 |
+
def update_test_map_with_core_pipelines(json_output_file: str):
|
| 836 |
+
print(f"\n### ADD CORE PIPELINE TESTS ###\n{_print_list(IMPORTANT_PIPELINES)}")
|
| 837 |
+
with open(json_output_file, "rb") as fp:
|
| 838 |
+
test_map = json.load(fp)
|
| 839 |
+
|
| 840 |
+
# Add core pipelines as their own test group
|
| 841 |
+
test_map["core_pipelines"] = " ".join(
|
| 842 |
+
sorted([str(PATH_TO_TESTS / f"pipelines/{pipe}") for pipe in IMPORTANT_PIPELINES])
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
# If there are no existing pipeline tests save the map
|
| 846 |
+
if "pipelines" not in test_map:
|
| 847 |
+
with open(json_output_file, "w", encoding="UTF-8") as fp:
|
| 848 |
+
json.dump(test_map, fp, ensure_ascii=False)
return  # nothing to merge when no pipeline tests were fetched
|
| 849 |
+
|
| 850 |
+
pipeline_tests = test_map.pop("pipelines")
|
| 851 |
+
pipeline_tests = pipeline_tests.split(" ")
|
| 852 |
+
|
| 853 |
+
# Remove core pipeline tests from the fetched pipeline tests
|
| 854 |
+
updated_pipeline_tests = []
|
| 855 |
+
for pipe in pipeline_tests:
|
| 856 |
+
if pipe == "tests/pipelines" or Path(pipe).parts[2] in IMPORTANT_PIPELINES:
|
| 857 |
+
continue
|
| 858 |
+
updated_pipeline_tests.append(pipe)
|
| 859 |
+
|
| 860 |
+
if len(updated_pipeline_tests) > 0:
|
| 861 |
+
test_map["pipelines"] = " ".join(sorted(updated_pipeline_tests))
|
| 862 |
+
|
| 863 |
+
with open(json_output_file, "w", encoding="UTF-8") as fp:
|
| 864 |
+
json.dump(test_map, fp, ensure_ascii=False)
|
| 865 |
+
|
| 866 |
+
|
| 867 |
+
def create_json_map(test_files_to_run: List[str], json_output_file: Optional[str] = None):
|
| 868 |
+
"""
|
| 869 |
+
Creates a map from a list of tests to run, to easily split them by category when running slow tests in parallel.
|
| 870 |
+
|
| 871 |
+
Args:
|
| 872 |
+
test_files_to_run (`List[str]`): The list of tests to run.
|
| 873 |
+
json_output_file (`str`): The path where to store the built json map.
|
| 874 |
+
"""
|
| 875 |
+
if json_output_file is None:
|
| 876 |
+
return
|
| 877 |
+
|
| 878 |
+
test_map = {}
|
| 879 |
+
for test_file in test_files_to_run:
|
| 880 |
+
# `test_file` is a path to a test folder/file, starting with `tests/`. For example,
|
| 881 |
+
# - `tests/models/bert/test_modeling_bert.py` or `tests/models/bert`
|
| 882 |
+
# - `tests/trainer/test_trainer.py` or `tests/trainer`
|
| 883 |
+
# - `tests/test_modeling_common.py`
|
| 884 |
+
names = test_file.split(os.path.sep)
|
| 885 |
+
module = names[1]
|
| 886 |
+
if module in MODULES_TO_IGNORE:
|
| 887 |
+
continue
|
| 888 |
+
|
| 889 |
+
if len(names) > 2 or not test_file.endswith(".py"):
|
| 890 |
+
# test folders under `tests` or python files under them
|
| 891 |
+
# take the part like tokenization, `pipeline`, etc. for other test categories
|
| 892 |
+
key = os.path.sep.join(names[1:2])
|
| 893 |
+
else:
|
| 894 |
+
# common test files directly under `tests/`
|
| 895 |
+
key = "common"
|
| 896 |
+
|
| 897 |
+
if key not in test_map:
|
| 898 |
+
test_map[key] = []
|
| 899 |
+
test_map[key].append(test_file)
|
| 900 |
+
|
| 901 |
+
# sort the keys & values
|
| 902 |
+
keys = sorted(test_map.keys())
|
| 903 |
+
test_map = {k: " ".join(sorted(test_map[k])) for k in keys}
|
| 904 |
+
|
| 905 |
+
with open(json_output_file, "w", encoding="UTF-8") as fp:
|
| 906 |
+
json.dump(test_map, fp, ensure_ascii=False)
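# Illustrative shape of the resulting JSON (hypothetical content): keys are the
# first-level folders under `tests/` (or "common" for files directly under it),
# values are space-separated test paths, e.g.
#     {"common": "tests/test_config.py", "pipelines": "tests/pipelines/test_pipelines.py"}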
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
def infer_tests_to_run(
|
| 910 |
+
output_file: str,
|
| 911 |
+
diff_with_last_commit: bool = False,
|
| 912 |
+
json_output_file: Optional[str] = None,
|
| 913 |
+
):
|
| 914 |
+
"""
|
| 915 |
+
The main function called by the test fetcher. Determines the tests to run from the diff.
|
| 916 |
+
|
| 917 |
+
Args:
|
| 918 |
+
output_file (`str`):
|
| 919 |
+
The path where to store the summary of the test fetcher analysis. Other files will be stored in the same
|
| 920 |
+
folder:
|
| 921 |
+
|
| 922 |
+
- examples_test_list.txt: The list of examples tests to run.
|
| 923 |
+
- test_repo_utils.txt: Will indicate if the repo utils tests should be run or not.
|
| 924 |
+
- doctest_list.txt: The list of doctests to run.
|
| 925 |
+
|
| 926 |
+
diff_with_last_commit (`bool`, *optional*, defaults to `False`):
|
| 927 |
+
Whether to analyze the diff with the last commit (for use on the main branch after a PR is merged) or with
|
| 928 |
+
the branching point from main (for use on each PR).
|
| 932 |
+
json_output_file (`str`, *optional*):
|
| 933 |
+
The path where to store the json file mapping categories of tests to tests to run (used for parallelism or
|
| 934 |
+
the slow tests).
|
| 935 |
+
"""
|
| 936 |
+
modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)
|
| 937 |
+
print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}")
|
| 938 |
+
# Create the map that will give us all impacted modules.
|
| 939 |
+
reverse_map = create_reverse_dependency_map()
|
| 940 |
+
impacted_files = modified_files.copy()
|
| 941 |
+
for f in modified_files:
|
| 942 |
+
if f in reverse_map:
|
| 943 |
+
impacted_files.extend(reverse_map[f])
|
| 944 |
+
|
| 945 |
+
# Remove duplicates
|
| 946 |
+
impacted_files = sorted(set(impacted_files))
|
| 947 |
+
print(f"\n### IMPACTED FILES ###\n{_print_list(impacted_files)}")
|
| 948 |
+
|
| 949 |
+
# Grab the corresponding test files:
|
| 950 |
+
if any(x in modified_files for x in ["setup.py"]):
|
| 951 |
+
test_files_to_run = ["tests", "examples"]
|
| 952 |
+
|
| 953 |
+
# in order to trigger pipeline tests even if no code change at all
|
| 954 |
+
if "tests/utils/tiny_model_summary.json" in modified_files:
|
| 955 |
+
test_files_to_run = ["tests"]
|
| 957 |
+
else:
|
| 958 |
+
# All modified tests need to be run.
|
| 959 |
+
test_files_to_run = [
|
| 960 |
+
f for f in modified_files if f.startswith("tests") and f.split(os.path.sep)[-1].startswith("test")
|
| 961 |
+
]
|
| 962 |
+
# Then we grab the corresponding test files.
|
| 963 |
+
test_map = create_module_to_test_map(reverse_map=reverse_map)
|
| 964 |
+
for f in modified_files:
|
| 965 |
+
if f in test_map:
|
| 966 |
+
test_files_to_run.extend(test_map[f])
|
| 967 |
+
test_files_to_run = sorted(set(test_files_to_run))
|
| 968 |
+
# Make sure we did not end up with a test file that was removed
|
| 969 |
+
test_files_to_run = [f for f in test_files_to_run if (PATH_TO_REPO / f).exists()]
|
| 970 |
+
|
| 972 |
+
|
| 973 |
+
examples_tests_to_run = [f for f in test_files_to_run if f.startswith("examples")]
|
| 974 |
+
test_files_to_run = [f for f in test_files_to_run if not f.startswith("examples")]
|
| 975 |
+
print(f"\n### TEST TO RUN ###\n{_print_list(test_files_to_run)}")
|
| 976 |
+
if len(test_files_to_run) > 0:
|
| 977 |
+
with open(output_file, "w", encoding="utf-8") as f:
|
| 978 |
+
f.write(" ".join(test_files_to_run))
|
| 979 |
+
|
| 980 |
+
# Create a map that maps test categories to test files, i.e. `models/bert` -> [...test_modeling_bert.py, ...]
|
| 981 |
+
|
| 982 |
+
# Get all test directories (and some common test files) under `tests` and `tests/models` if `test_files_to_run`
|
| 983 |
+
# contains `tests` (i.e. when `setup.py` is changed).
|
| 984 |
+
if "tests" in test_files_to_run:
|
| 985 |
+
test_files_to_run = get_all_tests()
|
| 986 |
+
|
| 987 |
+
create_json_map(test_files_to_run, json_output_file)
|
| 988 |
+
|
| 989 |
+
print(f"\n### EXAMPLES TEST TO RUN ###\n{_print_list(examples_tests_to_run)}")
|
| 990 |
+
if len(examples_tests_to_run) > 0:
|
| 991 |
+
# We use `all` in the case `commit_flags["test_all"]` as well as in `create_circleci_config.py` for processing
|
| 992 |
+
if examples_tests_to_run == ["examples"]:
|
| 993 |
+
examples_tests_to_run = ["all"]
|
| 994 |
+
example_file = Path(output_file).parent / "examples_test_list.txt"
|
| 995 |
+
with open(example_file, "w", encoding="utf-8") as f:
|
| 996 |
+
f.write(" ".join(examples_tests_to_run))
|
| 997 |
+
|
| 998 |
+
|
| 999 |
+
def filter_tests(output_file: str, filters: List[str]):
|
| 1000 |
+
"""
|
| 1001 |
+
Reads the content of the output file and filters out all the tests in a list of given folders.
|
| 1002 |
+
|
| 1003 |
+
Args:
|
| 1004 |
+
output_file (`str` or `os.PathLike`): The path to the output file of the tests fetcher.
|
| 1005 |
+
filters (`List[str]`): A list of folders to filter.
|
| 1006 |
+
"""
|
| 1007 |
+
if not os.path.isfile(output_file):
|
| 1008 |
+
print("No test file found.")
|
| 1009 |
+
return
|
| 1010 |
+
with open(output_file, "r", encoding="utf-8") as f:
|
| 1011 |
+
test_files = f.read().split(" ")
|
| 1012 |
+
|
| 1013 |
+
if len(test_files) == 0 or test_files == [""]:
|
| 1014 |
+
print("No tests to filter.")
|
| 1015 |
+
return
|
| 1016 |
+
|
| 1017 |
+
if test_files == ["tests"]:
|
| 1018 |
+
test_files = [os.path.join("tests", f) for f in os.listdir("tests") if f not in ["__init__.py"] + filters]
|
| 1019 |
+
else:
|
| 1020 |
+
test_files = [f for f in test_files if f.split(os.path.sep)[1] not in filters]
|
| 1021 |
+
|
| 1022 |
+
with open(output_file, "w", encoding="utf-8") as f:
|
| 1023 |
+
f.write(" ".join(test_files))
|
| 1024 |
+
|
| 1025 |
+
|
| 1026 |
+
def parse_commit_message(commit_message: str) -> Dict[str, bool]:
|
| 1027 |
+
"""
|
| 1028 |
+
Parses the commit message to detect if a command is there to skip, force all or part of the CI.
|
| 1029 |
+
|
| 1030 |
+
Args:
|
| 1031 |
+
commit_message (`str`): The commit message of the current commit.
|
| 1032 |
+
|
| 1033 |
+
Returns:
|
| 1034 |
+
`Dict[str, bool]`: A dictionary of strings to bools with the following keys: `"skip"`,
|
| 1035 |
+
`"no_filter"` and `"test_all"`.
|
| 1036 |
+
"""
|
| 1037 |
+
if commit_message is None:
|
| 1038 |
+
return {"skip": False, "no_filter": False, "test_all": False}
|
| 1039 |
+
|
| 1040 |
+
command_search = re.search(r"\[([^\]]*)\]", commit_message)
|
| 1041 |
+
if command_search is not None:
|
| 1042 |
+
command = command_search.groups()[0]
|
| 1043 |
+
command = command.lower().replace("-", " ").replace("_", " ")
|
| 1044 |
+
skip = command in ["ci skip", "skip ci", "circleci skip", "skip circleci"]
|
| 1045 |
+
no_filter = set(command.split(" ")) == {"no", "filter"}
|
| 1046 |
+
test_all = set(command.split(" ")) == {"test", "all"}
|
| 1047 |
+
return {"skip": skip, "no_filter": no_filter, "test_all": test_all}
|
| 1048 |
+
else:
|
| 1049 |
+
return {"skip": False, "no_filter": False, "test_all": False}
|
| 1050 |
+
|
| 1051 |
+
|
| 1052 |
+
if __name__ == "__main__":
|
| 1053 |
+
parser = argparse.ArgumentParser()
|
| 1054 |
+
parser.add_argument(
|
| 1055 |
+
"--output_file", type=str, default="test_list.txt", help="Where to store the list of tests to run"
|
| 1056 |
+
)
|
| 1057 |
+
parser.add_argument(
|
| 1058 |
+
"--json_output_file",
|
| 1059 |
+
type=str,
|
| 1060 |
+
default="test_map.json",
|
| 1061 |
+
help="Where to store the tests to run in a dictionary format mapping test categories to test files",
|
| 1062 |
+
)
|
| 1063 |
+
parser.add_argument(
|
| 1064 |
+
"--diff_with_last_commit",
|
| 1065 |
+
action="store_true",
|
| 1066 |
+
help="To fetch the tests between the current commit and the last commit",
|
| 1067 |
+
)
|
| 1068 |
+
parser.add_argument(
|
| 1069 |
+
"--filter_tests",
|
| 1070 |
+
action="store_true",
|
| 1071 |
+
help="Will filter the pipeline/repo utils tests outside of the generated list of tests.",
|
| 1072 |
+
)
|
| 1073 |
+
parser.add_argument(
|
| 1074 |
+
"--print_dependencies_of",
|
| 1075 |
+
type=str,
|
| 1076 |
+
help="Will only print the tree of modules depending on the file passed.",
|
| 1077 |
+
default=None,
|
| 1078 |
+
)
|
| 1079 |
+
parser.add_argument(
|
| 1080 |
+
"--commit_message",
|
| 1081 |
+
type=str,
|
| 1082 |
+
help="The commit message (which could contain a command to force all tests or skip the CI).",
|
| 1083 |
+
default=None,
|
| 1084 |
+
)
|
| 1085 |
+
args = parser.parse_args()
|
| 1086 |
+
if args.print_dependencies_of is not None:
|
| 1087 |
+
print_tree_deps_of(args.print_dependencies_of)
|
| 1088 |
+
else:
|
| 1089 |
+
repo = Repo(PATH_TO_REPO)
|
| 1090 |
+
commit_message = repo.head.commit.message
|
| 1091 |
+
commit_flags = parse_commit_message(commit_message)
|
| 1092 |
+
if commit_flags["skip"]:
|
| 1093 |
+
print("Force-skipping the CI")
|
| 1094 |
+
quit()
|
| 1095 |
+
if commit_flags["no_filter"]:
|
| 1096 |
+
print("Running all tests fetched without filtering.")
|
| 1097 |
+
if commit_flags["test_all"]:
|
| 1098 |
+
print("Force-launching all tests")
|
| 1099 |
+
|
| 1100 |
+
diff_with_last_commit = args.diff_with_last_commit
|
| 1101 |
+
if not diff_with_last_commit and not repo.head.is_detached and repo.head.ref == repo.refs.main:
|
| 1102 |
+
print("main branch detected, fetching tests against last commit.")
|
| 1103 |
+
diff_with_last_commit = True
|
| 1104 |
+
|
| 1105 |
+
if not commit_flags["test_all"]:
|
| 1106 |
+
try:
|
| 1107 |
+
infer_tests_to_run(
|
| 1108 |
+
args.output_file,
|
| 1109 |
+
diff_with_last_commit=diff_with_last_commit,
|
| 1110 |
+
json_output_file=args.json_output_file,
|
| 1111 |
+
)
|
| 1112 |
+
filter_tests(args.output_file, ["repo_utils"])
|
| 1113 |
+
update_test_map_with_core_pipelines(json_output_file=args.json_output_file)
|
| 1114 |
+
|
| 1115 |
+
except Exception as e:
|
| 1116 |
+
print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.")
|
| 1117 |
+
commit_flags["test_all"] = True
|
| 1118 |
+
|
| 1119 |
+
if commit_flags["test_all"]:
|
| 1120 |
+
with open(args.output_file, "w", encoding="utf-8") as f:
|
| 1121 |
+
f.write("tests")
|
| 1122 |
+
example_file = Path(args.output_file).parent / "examples_test_list.txt"
|
| 1123 |
+
with open(example_file, "w", encoding="utf-8") as f:
|
| 1124 |
+
f.write("all")
|
| 1125 |
+
|
| 1126 |
+
test_files_to_run = get_all_tests()
|
| 1127 |
+
create_json_map(test_files_to_run, args.json_output_file)
|
| 1128 |
+
update_test_map_with_core_pipelines(json_output_file=args.json_output_file)
|