xiaoanyu123 commited on
Commit
e291b5b
·
verified ·
1 Parent(s): 8f6d75a

Add files using upload-large-folder tool

Browse files
.gitattributes CHANGED
@@ -266,3 +266,15 @@ pythonProject/.venv/Scripts/diffusers-cli.exe filter=lfs diff=lfs merge=lfs -tex
266
  pythonProject/.venv/Scripts/check-node.exe filter=lfs diff=lfs merge=lfs -text
267
  pythonProject/.venv/Scripts/hf.exe filter=lfs diff=lfs merge=lfs -text
268
  pythonProject/.venv/Scripts/huggingface-cli.exe filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  pythonProject/.venv/Scripts/check-node.exe filter=lfs diff=lfs merge=lfs -text
267
  pythonProject/.venv/Scripts/hf.exe filter=lfs diff=lfs merge=lfs -text
268
  pythonProject/.venv/Scripts/huggingface-cli.exe filter=lfs diff=lfs merge=lfs -text
269
+ pythonProject/.venv/Scripts/numpy-config.exe filter=lfs diff=lfs merge=lfs -text
270
+ pythonProject/.venv/Scripts/isympy.exe filter=lfs diff=lfs merge=lfs -text
271
+ pythonProject/.venv/Scripts/pip3.10.exe filter=lfs diff=lfs merge=lfs -text
272
+ pythonProject/.venv/Scripts/pip.exe filter=lfs diff=lfs merge=lfs -text
273
+ pythonProject/.venv/Scripts/pip3.exe filter=lfs diff=lfs merge=lfs -text
274
+ pythonProject/.venv/Scripts/python.exe filter=lfs diff=lfs merge=lfs -text
275
+ pythonProject/.venv/Scripts/tiny-agents.exe filter=lfs diff=lfs merge=lfs -text
276
+ pythonProject/.venv/Scripts/pythonw.exe filter=lfs diff=lfs merge=lfs -text
277
+ pythonProject/.venv/Scripts/torchfrtrace.exe filter=lfs diff=lfs merge=lfs -text
278
+ pythonProject/.venv/Scripts/torchrun.exe filter=lfs diff=lfs merge=lfs -text
279
+ pythonProject/.venv/Scripts/tqdm.exe filter=lfs diff=lfs merge=lfs -text
280
+ pythonProject/diffusers-main/docs/source/en/imgs/access_request.png filter=lfs diff=lfs merge=lfs -text
pythonProject/.venv/Scripts/isympy.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e0939da22ca6f0052870e8dc3da42a78679a46b194e96ed827be31c9a834471
3
+ size 108395
pythonProject/.venv/Scripts/numpy-config.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6c0c18062b265bbf63daa978ed69ff5cbb63cda8dcaae7787b563d61d0cd29e
3
+ size 108406
pythonProject/.venv/Scripts/pip.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df129c9009bd7ba09b466cbafbd196813830eef8c03f3009c1f30946bb377acb
3
+ size 108411
pythonProject/.venv/Scripts/pip3.10.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df129c9009bd7ba09b466cbafbd196813830eef8c03f3009c1f30946bb377acb
3
+ size 108411
pythonProject/.venv/Scripts/pip3.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df129c9009bd7ba09b466cbafbd196813830eef8c03f3009c1f30946bb377acb
3
+ size 108411
pythonProject/.venv/Scripts/python.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2c836c52cdf063180b9ee76f67ac42946101b79ac457f3494035a67c090d961
3
+ size 268568
pythonProject/.venv/Scripts/pythonw.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c37c6c3c14074b69b3b8543ed5edcf6666b876c8f73672e52b58f6bbc5dd3ee
3
+ size 257304
pythonProject/.venv/Scripts/tiny-agents.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cfbbcb52bbb876ef485061f31d62ec773a51726d305409a0147aab6b4eaf2a9f
3
+ size 108421
pythonProject/.venv/Scripts/torchfrtrace.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b3d9da8a20090cb5417548b5098b6dba389755916c493e460a782fdb791c3e2
3
+ size 108419
pythonProject/.venv/Scripts/torchrun.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4d393a13ca96560946983d8f7ed998c83726bc94584ddfd92942bea7de12565
3
+ size 108410
pythonProject/.venv/Scripts/tqdm.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8b93dfc2cd0722449177b0dd541c8b5ce17d3f3153f6a0ee7c0778ce419cd3b
3
+ size 108397
pythonProject/diffusers-main/docs/source/en/imgs/access_request.png ADDED

Git LFS Details

  • SHA256: 9688dabf75e180590251cd1f75d18966f9c94d5d6584bc7d0278b698c175c61f
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB
pythonProject/diffusers-main/utils/notify_benchmarking_status.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import os
18
+
19
+ import requests
20
+
21
+
22
+ # Configuration
23
+ GITHUB_REPO = "huggingface/diffusers"
24
+ GITHUB_RUN_ID = os.getenv("GITHUB_RUN_ID")
25
+ SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL")
26
+
27
+
28
def main(args):
    """Post the benchmarking-workflow status to Slack via the configured webhook.

    Args:
        args: Parsed CLI namespace with a ``status`` attribute
            ("success" or "failure").
    """
    action_url = f"https://github.com/{GITHUB_REPO}/actions/runs/{GITHUB_RUN_ID}"
    if args.status == "success":
        hub_path = "https://huggingface.co/datasets/diffusers/benchmarks/blob/main/collated_results.csv"
        message = (
            "✅ New benchmark workflow successfully run.\n"
            f"🕸️ GitHub Action URL: {action_url}.\n"
            f"🤗 Check out the benchmarks here: {hub_path}."
        )
    else:
        message = (
            "❌ Something wrong happened in the benchmarking workflow.\n"
            f"Check out the GitHub Action to know more: {action_url}."
        )

    payload = {"text": message}
    # Fix: a timeout prevents the CI job from hanging indefinitely if Slack is unreachable.
    response = requests.post(SLACK_WEBHOOK_URL, json=payload, timeout=60)

    if response.status_code == 200:
        print("Notification sent to Slack successfully.")
    else:
        print("Failed to send notification to Slack.")
50
+
51
+
52
if __name__ == "__main__":
    # CLI entry point: the CI workflow passes --status success|failure.
    parser = argparse.ArgumentParser()
    parser.add_argument("--status", type=str, default="success", choices=["success", "failure"])
    args = parser.parse_args()
    main(args)
pythonProject/diffusers-main/utils/notify_community_pipelines_mirror.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import os
18
+
19
+ import requests
20
+
21
+
22
+ # Configuration
23
+ GITHUB_REPO = "huggingface/diffusers"
24
+ GITHUB_RUN_ID = os.getenv("GITHUB_RUN_ID")
25
+ SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL")
26
+ PATH_IN_REPO = os.getenv("PATH_IN_REPO")
27
+
28
+
29
def main(args):
    """Post the community-pipelines mirroring status to Slack via the configured webhook.

    Args:
        args: Parsed CLI namespace with a ``status`` attribute
            ("success" or "failure").
    """
    action_url = f"https://github.com/{GITHUB_REPO}/actions/runs/{GITHUB_RUN_ID}"
    if args.status == "success":
        hub_path = f"https://huggingface.co/datasets/diffusers/community-pipelines-mirror/tree/main/{PATH_IN_REPO}"
        message = (
            "✅ Community pipelines successfully mirrored.\n"
            f"🕸️ GitHub Action URL: {action_url}.\n"
            f"🤗 Hub location: {hub_path}."
        )
    else:
        message = f"❌ Something wrong happened. Check out the GitHub Action to know more: {action_url}."

    payload = {"text": message}
    # Fix: a timeout prevents the CI job from hanging indefinitely if Slack is unreachable.
    response = requests.post(SLACK_WEBHOOK_URL, json=payload, timeout=60)

    if response.status_code == 200:
        print("Notification sent to Slack successfully.")
    else:
        print("Failed to send notification to Slack.")
48
+
49
+
50
if __name__ == "__main__":
    # CLI entry point: the CI workflow passes --status success|failure.
    parser = argparse.ArgumentParser()
    parser.add_argument("--status", type=str, default="success", choices=["success", "failure"])
    args = parser.parse_args()
    main(args)
pythonProject/diffusers-main/utils/notify_slack_about_release.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ import requests
19
+
20
+
21
+ # Configuration
22
+ LIBRARY_NAME = "diffusers"
23
+ GITHUB_REPO = "huggingface/diffusers"
24
+ SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL")
25
+
26
+
27
def check_pypi_for_latest_release(library_name):
    """Return the latest released version of `library_name` on PyPI, or None on failure."""
    resp = requests.get(f"https://pypi.org/pypi/{library_name}/json", timeout=60)
    # Guard clause: anything but 200 is treated as a lookup failure.
    if resp.status_code != 200:
        print("Failed to fetch library details from PyPI.")
        return None
    return resp.json()["info"]["version"]
36
+
37
+
38
def get_github_release_info(github_repo):
    """Return tag, URL, and publish time of the latest GitHub release, or None on failure."""
    endpoint = f"https://api.github.com/repos/{github_repo}/releases/latest"
    resp = requests.get(endpoint, timeout=60)

    # Guard clause: anything but 200 is treated as a lookup failure.
    if resp.status_code != 200:
        print("Failed to fetch release info from GitHub.")
        return None

    payload = resp.json()
    return {
        "tag_name": payload["tag_name"],
        "url": payload["html_url"],
        "release_time": payload["published_at"],
    }
50
+
51
+
52
def notify_slack(webhook_url, library_name, version, release_info):
    """Send a release announcement to a Slack channel via an incoming webhook.

    Args:
        webhook_url: Slack incoming-webhook URL.
        library_name: Display name of the library being released.
        version: Version string being announced.
        release_info: Mapping with at least ``url`` and ``release_time`` keys.
    """
    message = (
        f"🚀 New release for {library_name} available: version **{version}** 🎉\n"
        f"📜 Release Notes: {release_info['url']}\n"
        f"⏱️ Release time: {release_info['release_time']}"
    )
    payload = {"text": message}
    # Fix: the GET calls in this file use timeout=60 but this POST had none and
    # could hang the job indefinitely; make it consistent.
    response = requests.post(webhook_url, json=payload, timeout=60)

    if response.status_code == 200:
        print("Notification sent to Slack successfully.")
    else:
        print("Failed to send notification to Slack.")
66
+
67
+
68
def main():
    """Compare the latest PyPI release with the latest GitHub release and notify Slack.

    Raises:
        ValueError: if either lookup failed or the two versions disagree.
    """
    latest_version = check_pypi_for_latest_release(LIBRARY_NAME)
    release_info = get_github_release_info(GITHUB_REPO)
    # Fix: get_github_release_info returns None on failure, and the original code
    # indexed release_info["tag_name"] unconditionally — crashing with a TypeError
    # instead of reaching the intended diagnostic ValueError below.
    parsed_version = release_info["tag_name"].replace("v", "") if release_info else None

    if latest_version and release_info and latest_version == parsed_version:
        notify_slack(SLACK_WEBHOOK_URL, LIBRARY_NAME, latest_version, release_info)
    else:
        print(f"{latest_version=}, {release_info=}, {parsed_version=}")
        raise ValueError("There were some problems.")
78
+
79
+
80
if __name__ == "__main__":
    # Script entry point: run the release check once (intended for CI cron).
    main()
pythonProject/diffusers-main/utils/overwrite_expected_slice.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import argparse
16
+ from collections import defaultdict
17
+
18
+
19
def overwrite_file(file, class_name, test_name, correct_line, done_test):
    """Overwrite one expected-slice statement inside `test_name` of `class_name` in `file`.

    The statement to replace is located by prefix matching: it must start with the
    first token of `correct_line` at 8- or 16-space indentation, inside the target
    class and test method. `done_test` counts how many times this (file, class, test)
    triple has been processed so far, so repeated entries overwrite successive
    occurrences instead of always the first one.
    """
    _id = f"{file}_{class_name}_{test_name}"
    done_test[_id] += 1

    with open(file, "r") as f:
        lines = f.readlines()

    # Prefixes used to locate the class header, the test method, and the target statement.
    class_regex = f"class {class_name}("
    test_regex = f"{4 * ' '}def {test_name}("
    line_begin_regex = f"{8 * ' '}{correct_line.split()[0]}"
    another_line_begin_regex = f"{16 * ' '}{correct_line.split()[0]}"
    in_class = False      # currently inside the target class
    in_func = False       # currently inside the target test method
    in_line = False       # currently inside the (possibly multi-line) statement being replaced
    insert_line = False   # the closing ")" of the old statement has been reached
    count = 0             # number of candidate statements seen so far
    spaces = 0            # indentation width of the statement being replaced

    new_lines = []
    for line in lines:
        if line.startswith(class_regex):
            in_class = True
        elif in_class and line.startswith(test_regex):
            in_func = True
        elif in_class and in_func and (line.startswith(line_begin_regex) or line.startswith(another_line_begin_regex)):
            # Everything before the first token is indentation.
            spaces = len(line.split(correct_line.split()[0])[0])
            count += 1

            # Only start replacing at the occurrence matching this entry's ordinal.
            if count == done_test[_id]:
                in_line = True

        if in_class and in_func and in_line:
            # Drop every line of the old statement until its closing parenthesis.
            if ")" not in line:
                continue
            else:
                insert_line = True

        if in_class and in_func and in_line and insert_line:
            # Emit the replacement at the original indentation, then reset all state.
            new_lines.append(f"{spaces * ' '}{correct_line}")
            in_class = in_func = in_line = insert_line = False
        else:
            new_lines.append(line)

    with open(file, "w") as f:
        for line in new_lines:
            f.write(line)
65
+
66
+
67
def main(correct, fail=None):
    """Apply every correction listed in `correct`, optionally limited to failing tests.

    Each line of `correct` has the form ``file::class::test::replacement_line``. When
    `fail` is given, only entries whose ``file::class::test`` id appears in that file
    are applied.
    """
    failures = None
    if fail is not None:
        with open(fail, "r") as fh:
            failures = {entry.strip() for entry in fh.readlines()}

    with open(correct, "r") as fh:
        corrections = fh.readlines()

    seen = defaultdict(int)
    for correction in corrections:
        path, cls, test, replacement = correction.split("::")
        if failures is None or "::".join([path, cls, test]) in failures:
            overwrite_file(path, cls, test, replacement, seen)
82
+
83
+
84
if __name__ == "__main__":
    # CLI entry point: --correct_filename is required in practice; --fail_filename
    # restricts the overwrite to tests listed as failures.
    parser = argparse.ArgumentParser()
    parser.add_argument("--correct_filename", help="filename of tests with expected result")
    parser.add_argument("--fail_filename", help="filename of test failures", type=str, default=None)
    args = parser.parse_args()

    main(args.correct_filename, args.fail_filename)
pythonProject/diffusers-main/utils/print_env.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # coding=utf-8
4
+ # Copyright 2025 The HuggingFace Inc. team.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ # this script dumps information about the environment
19
+
20
+ import os
21
+ import platform
22
+ import sys
23
+
24
+
25
# Silence TensorFlow C++ logging before any TF-dependent import can emit warnings.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

print("Python version:", sys.version)

print("OS platform:", platform.platform())
print("OS architecture:", platform.machine())
try:
    import psutil

    # RAM stats are best-effort: psutil is an optional dependency, silently skipped if absent.
    vm = psutil.virtual_memory()
    total_gb = vm.total / (1024**3)
    available_gb = vm.available / (1024**3)
    print(f"Total RAM: {total_gb:.2f} GB")
    print(f"Available RAM: {available_gb:.2f} GB")
except ImportError:
    pass

try:
    import torch

    print("Torch version:", torch.__version__)
    print("Cuda available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("Cuda version:", torch.version.cuda)
        print("CuDNN version:", torch.backends.cudnn.version())
        print("Number of GPUs available:", torch.cuda.device_count())
        device_properties = torch.cuda.get_device_properties(0)
        total_memory = device_properties.total_memory / (1024**3)
        print(f"CUDA memory: {total_memory} GB")

    # XPU (Intel GPU) support is optional; hasattr guards older torch builds without torch.xpu.
    print("XPU available:", hasattr(torch, "xpu") and torch.xpu.is_available())
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        print("XPU model:", torch.xpu.get_device_properties(0).name)
        print("XPU compiler version:", torch.version.xpu)
        print("Number of XPUs available:", torch.xpu.device_count())
        device_properties = torch.xpu.get_device_properties(0)
        total_memory = device_properties.total_memory / (1024**3)
        print(f"XPU memory: {total_memory} GB")


except ImportError:
    # No torch at all: report None rather than crashing.
    print("Torch version:", None)

try:
    import transformers

    print("transformers version:", transformers.__version__)
except ImportError:
    print("transformers version:", None)
pythonProject/diffusers-main/utils/release.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import os
18
+ import re
19
+
20
+ import packaging.version
21
+
22
+
23
+ PATH_TO_EXAMPLES = "examples/"
24
+ REPLACE_PATTERNS = {
25
+ "examples": (re.compile(r'^check_min_version\("[^"]+"\)\s*$', re.MULTILINE), 'check_min_version("VERSION")\n'),
26
+ "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'),
27
+ "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'),
28
+ "doc": (re.compile(r'^(\s*)release\s*=\s*"[^"]+"$', re.MULTILINE), 'release = "VERSION"\n'),
29
+ }
30
+ REPLACE_FILES = {
31
+ "init": "src/diffusers/__init__.py",
32
+ "setup": "setup.py",
33
+ }
34
+ README_FILE = "README.md"
35
+
36
+
37
def update_version_in_file(fname, version, pattern):
    """Rewrite the version string in `fname` using the named entry of REPLACE_PATTERNS."""
    regex, template = REPLACE_PATTERNS[pattern]
    replacement = template.replace("VERSION", version)
    with open(fname, "r", encoding="utf-8", newline="\n") as f:
        contents = f.read()
    with open(fname, "w", encoding="utf-8", newline="\n") as f:
        f.write(regex.sub(replacement, contents))
46
+
47
+
48
def update_version_in_examples(version):
    """Update the `check_min_version` pin in every actively maintained example script."""
    for folder, directories, fnames in os.walk(PATH_TO_EXAMPLES):
        # Prune non-actively maintained example folders from the walk in place.
        for skipped in ("research_projects", "legacy"):
            if skipped in directories:
                directories.remove(skipped)
        for fname in fnames:
            if fname.endswith(".py"):
                update_version_in_file(os.path.join(folder, fname), version, pattern="examples")
59
+
60
+
61
def global_version_update(version, patch=False):
    """Propagate `version` to every pinned file; example pins are skipped for patch releases."""
    for pattern, fname in REPLACE_FILES.items():
        update_version_in_file(fname, version, pattern)
    if patch:
        return
    update_version_in_examples(version)
67
+
68
+
69
def clean_main_ref_in_model_list():
    """Replace the links from main doc to stable doc in the model list of the README."""
    # If the introduction or the conclusion of the list change, the prompts may need to be updated.
    # NOTE(review): the start prompt says "Transformers" — this helper looks copied from the
    # transformers repo; confirm the prompts actually match this project's README before use.
    _start_prompt = "🤗 Transformers currently provides the following architectures"
    _end_prompt = "1. Want to contribute a new model?"
    with open(README_FILE, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Find the start of the list. Will raise IndexError if the prompt is absent.
    start_index = 0
    while not lines[start_index].startswith(_start_prompt):
        start_index += 1
    start_index += 1

    index = start_index
    # Update the lines in the model list.
    while not lines[index].startswith(_end_prompt):
        if lines[index].startswith("1."):
            lines[index] = lines[index].replace(
                "https://huggingface.co/docs/diffusers/main/model_doc",
                "https://huggingface.co/docs/diffusers/model_doc",
            )
        index += 1

    with open(README_FILE, "w", encoding="utf-8", newline="\n") as f:
        f.writelines(lines)
95
+
96
+
97
def get_version():
    """Read and parse the current version from the package `__init__`."""
    with open(REPLACE_FILES["init"], "r") as f:
        source = f.read()
    raw_version = REPLACE_PATTERNS["init"][0].search(source).groups()[0]
    return packaging.version.parse(raw_version)
103
+
104
+
105
def pre_release_work(patch=False):
    """Do all the necessary pre-release steps."""
    # Default candidate: base version when on a dev release, otherwise bump micro/minor.
    default_version = get_version()
    if patch and default_version.is_devrelease:
        raise ValueError("Can't create a patch version from the dev branch, checkout a released version!")
    if default_version.is_devrelease:
        default_version = default_version.base_version
    elif patch:
        default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}"
    else:
        default_version = f"{default_version.major}.{default_version.minor + 1}.0"

    # Confirm interactively; an empty answer accepts the default.
    version = input(f"Which version are you releasing? [{default_version}]")
    if not version:
        version = default_version

    print(f"Updating version to {version}.")
    global_version_update(version, patch=patch)

    # if not patch:
    #     print("Cleaning main README, don't forget to run `make fix-copies`.")
    #     clean_main_ref_in_model_list()
130
+
131
+
132
def post_release_work():
    """Do all the necessary post-release steps."""
    # Derive the next dev version from the version that was just released.
    current_version = get_version()
    dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0"
    current_version = current_version.base_version

    # Confirm interactively; an empty answer accepts the default.
    version = input(f"Which version are we developing now? [{dev_version}]")
    if not version:
        version = dev_version

    print(f"Updating version to {version}.")
    global_version_update(version)

    # print("Cleaning main README, don't forget to run `make fix-copies`.")
    # clean_main_ref_in_model_list()
150
+
151
+
152
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.")
    parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.")
    args = parser.parse_args()
    # Pre-release is the default; a patch release needs no post-release version bump.
    if not args.post_release:
        pre_release_work(patch=args.patch)
    elif args.patch:
        print("Nothing to do after a patch :-)")
    else:
        post_release_work()
pythonProject/diffusers-main/utils/stale.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team, the AllenNLP library authors. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Script to close stale issue. Taken in part from the AllenNLP repository.
16
+ https://github.com/allenai/allennlp.
17
+ """
18
+
19
+ import os
20
+ from datetime import datetime as dt
21
+ from datetime import timezone
22
+
23
+ from github import Github
24
+
25
+
26
+ LABELS_TO_EXEMPT = [
27
+ "close-to-merge",
28
+ "good first issue",
29
+ "good second issue",
30
+ "good difficult issue",
31
+ "enhancement",
32
+ "new pipeline/model",
33
+ "new scheduler",
34
+ "wip",
35
+ ]
36
+
37
+
38
def main():
    """Walk all open issues of huggingface/diffusers and apply the stale-bot policy.

    - Issues already labeled "stale": if the most recent comment came from someone
      other than the GitHub Actions bot, keep the issue open and drop the label.
    - Otherwise, issues inactive for more than 23 days, at least 30 days old, and
      carrying no exempt label receive a stale notice and the "stale" label.

    NOTE(review): indentation was lost in this paste; the elif is placed at the
    `if "stale" in labels` level to match the upstream script — confirm.
    """
    g = Github(os.environ["GITHUB_TOKEN"])
    repo = g.get_repo("huggingface/diffusers")
    open_issues = repo.get_issues(state="open")

    for issue in open_issues:
        labels = [label.name.lower() for label in issue.get_labels()]
        if "stale" in labels:
            # Most recent comment first.
            comments = sorted(issue.get_comments(), key=lambda i: i.created_at, reverse=True)
            last_comment = comments[0] if len(comments) > 0 else None
            if last_comment is not None and last_comment.user.login != "github-actions[bot]":
                # Opens the issue if someone other than Stalebot commented.
                issue.edit(state="open")
                issue.remove_from_labels("stale")
        elif (
            (dt.now(timezone.utc) - issue.updated_at).days > 23
            and (dt.now(timezone.utc) - issue.created_at).days >= 30
            and not any(label in LABELS_TO_EXEMPT for label in labels)
        ):
            # Post a Stalebot notification after 23 days of inactivity.
            issue.create_comment(
                "This issue has been automatically marked as stale because it has not had "
                "recent activity. If you think this still needs to be addressed "
                "please comment on this thread.\n\nPlease note that issues that do not follow the "
                "[contributing guidelines](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md) "
                "are likely to be ignored."
            )
            issue.add_to_labels("stale")
66
+
67
+
68
if __name__ == "__main__":
    # Script entry point: intended to be run by a scheduled GitHub Action.
    main()
pythonProject/diffusers-main/utils/tests_fetcher.py ADDED
@@ -0,0 +1,1128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ Welcome to tests_fetcher V2.
18
+
19
+ This util is designed to fetch tests to run on a PR so that only the tests impacted by the modifications are run, and
20
+ when too many models are being impacted, only run the tests of a subset of core models. It works like this.
21
+
22
+ Stage 1: Identify the modified files. For jobs that run on the main branch, it's just the diff with the last commit.
23
+ On a PR, this takes all the files from the branching point to the current commit (so all modifications in a PR, not
24
+ just the last commit) but excludes modifications that are on docstrings or comments only.
25
+
26
+ Stage 2: Extract the tests to run. This is done by looking at the imports in each module and test file: if module A
27
+ imports module B, then changing module B impacts module A, so the tests using module A should be run. We thus get the
28
+ dependencies of each model and then recursively builds the 'reverse' map of dependencies to get all modules and tests
29
+ impacted by a given file. We then only keep the tests (and only the core models tests if there are too many modules).
30
+
31
+ Caveats:
32
+ - This module only filters tests by files (not individual tests) so it's better to have tests for different things
33
+ in different files.
34
+ - This module assumes inits are just importing things, not really building objects, so it's better to structure
35
+ them this way and move objects building in separate submodules.
36
+
37
+ Usage:
38
+
39
+ Base use to fetch the tests in a pull request
40
+
41
+ ```bash
42
+ python utils/tests_fetcher.py
43
+ ```
44
+
45
+ Base use to fetch the tests on the main branch (with diff from the last commit):
46
+
47
+ ```bash
48
+ python utils/tests_fetcher.py --diff_with_last_commit
49
+ ```
50
+ """
51
+
52
+ import argparse
53
+ import collections
54
+ import json
55
+ import os
56
+ import re
57
+ from contextlib import contextmanager
58
+ from pathlib import Path
59
+ from typing import Dict, List, Optional, Tuple, Union
60
+
61
+ from git import Repo
62
+
63
+
64
# Repository layout: this file lives in utils/, so the repo root is two levels up.
PATH_TO_REPO = Path(__file__).parent.parent.resolve()
PATH_TO_EXAMPLES = PATH_TO_REPO / "examples"
PATH_TO_DIFFUSERS = PATH_TO_REPO / "src/diffusers"
PATH_TO_TESTS = PATH_TO_REPO / "tests"

# Ignore fixtures in the tests folder, and ignore lora since it is always tested.
MODULES_TO_IGNORE = ["fixtures", "lora"]

# Core pipelines: per the module docstring, when too many modules are impacted only a subset of
# core pipeline tests is run.
IMPORTANT_PIPELINES = [
    "controlnet",
    "stable_diffusion",
    "stable_diffusion_2",
    "stable_diffusion_xl",
    "stable_video_diffusion",
    "deepfloyd_if",
    "kandinsky",
    "kandinsky2_2",
    "text_to_video_synthesis",
    "wuerstchen",
]
85
+
86
+
87
@contextmanager
def checkout_commit(repo: Repo, commit_id: str):
    """
    Temporarily check out `commit_id`, restoring the original reference on exit.

    Args:
        repo (`git.Repo`): The git repository to operate on.
        commit_id (`str`): The commit reference to check out inside the context manager.
    """
    # A detached HEAD has no ref to return to, so remember the commit itself in that case.
    original_ref = repo.head.commit if repo.head.is_detached else repo.head.ref

    try:
        repo.git.checkout(commit_id)
        yield
    finally:
        # Always restore where we started, even if the body raised.
        repo.git.checkout(original_ref)
104
+
105
+
106
def clean_code(content: str) -> str:
    """
    Strip docstrings, comments and blank lines from some code (used to detect whether a diff touches
    real code or only documentation).

    Args:
        content (`str`): The code to clean.

    Returns:
        `str`: The cleaned code.
    """
    # Autoformatting is disabled below so we can write escaped triple quotes; real triple quotes
    # would break this function when it is applied to this very file.
    # fmt: off
    # Drop triple-double-quoted then triple-single-quoted strings by keeping every other split.
    content = "".join(content.split('\"\"\"')[::2])
    content = "".join(content.split("\'\'\'")[::2])
    # fmt: on

    kept = []
    for raw_line in content.split("\n"):
        # Strip anything after a # sign, then skip lines that are now empty or whitespace-only.
        code_only = re.sub("#.*$", "", raw_line)
        if code_only and not code_only.isspace():
            kept.append(code_only)
    return "\n".join(kept)
136
+
137
+
138
def keep_doc_examples_only(content: str) -> str:
    """
    Remove everything from some code/doc content except the fenced doc examples (used to determine
    whether a diff should trigger doc tests).

    Args:
        content (`str`): The code to clean.

    Returns:
        `str`: The cleaned code.
    """
    # The odd-indexed splits are the bodies of the ``` fenced blocks; re-add leading/trailing fences
    # so navigation is easier when comparing against the original input.
    fenced_blocks = content.split("```")[1::2]
    content = "```" + "```".join(fenced_blocks) + "```"

    kept = []
    for raw_line in content.split("\n"):
        # Strip anything after a # sign, then skip lines that are now empty or whitespace-only.
        code_only = re.sub("#.*$", "", raw_line)
        if code_only and not code_only.isspace():
            kept.append(code_only)
    return "\n".join(kept)
163
+
164
+
165
def get_all_tests() -> List[str]:
    """
    Walk the `tests` folder and return the list of files/subfolders used to split the jobs when
    running tests with parallelism:

    - folders directly under `tests` (`pipelines`, etc.),
    - test files directly under `tests` (`test_modeling_common.py`, etc.).
    """
    # Everything directly under `tests`, minus python caches.
    entries = [f"tests/{entry}" for entry in os.listdir(PATH_TO_TESTS) if "__pycache__" not in entry]
    # Keep directories and top-level test_*.py files, in a deterministic order.
    return sorted(entry for entry in entries if (PATH_TO_REPO / entry).is_dir() or entry.startswith("tests/test_"))
181
+
182
+
183
def diff_is_docstring_only(repo: Repo, branching_point: str, filename: str) -> bool:
    """
    Check whether the diff in `filename` only touches docstrings, comments or whitespace.

    Args:
        repo (`git.Repo`): The git repository to inspect.
        branching_point (`str`): The commit reference to compare the working tree against.
        filename (`str`): The file (relative to the repo root) to inspect.

    Returns:
        `bool`: Whether the diff is docstring/comments only or not.
    """
    folder = Path(repo.working_dir)

    # Read the file as it was at the branching point...
    with checkout_commit(repo, branching_point):
        old_content = (folder / filename).read_text(encoding="utf-8")

    # ... and as it is now, then compare the two once stripped of docstrings/comments.
    new_content = (folder / filename).read_text(encoding="utf-8")

    return clean_code(old_content) == clean_code(new_content)
207
+
208
+
209
def diff_contains_doc_examples(repo: Repo, branching_point: str, filename: str) -> bool:
    """
    Check whether the diff in `filename` touches the doc examples of that file.

    Args:
        repo (`git.Repo`): The git repository to inspect.
        branching_point (`str`): The commit reference to compare the working tree against.
        filename (`str`): The file (relative to the repo root) to inspect.

    Returns:
        `bool`: Whether the diff changes the doc examples or not.
    """
    folder = Path(repo.working_dir)

    # File content at the branching point...
    with checkout_commit(repo, branching_point):
        old_content = (folder / filename).read_text(encoding="utf-8")

    # ... and in the current working tree; compare only their doc examples.
    new_content = (folder / filename).read_text(encoding="utf-8")

    return keep_doc_examples_only(old_content) != keep_doc_examples_only(new_content)
233
+
234
+
235
def get_diff(repo: Repo, base_commit: str, commits: List[str]) -> List[str]:
    """
    Get the Python files with a relevant diff between a base commit and one or several commits.

    Args:
        repo (`git.Repo`):
            The git repository to inspect.
        base_commit (`str`):
            The current commit (not the branching point!) to compare against.
        commits (`List[str]`):
            The commits (usually the branching point(s)) with which to compare the repo at
            `base_commit`.

    Returns:
        `List[str]`: Python files that were added, deleted or renamed, plus modified files whose diff
        is not limited to docstrings or comments (see `diff_is_docstring_only`).
    """
    print("\n### DIFF ###\n")
    code_diff = []
    for commit in commits:
        for diff_obj in commit.diff(base_commit):
            change_type = diff_obj.change_type
            if change_type == "A" and diff_obj.b_path.endswith(".py"):
                # New python files are always relevant.
                code_diff.append(diff_obj.b_path)
            elif change_type == "D" and diff_obj.a_path.endswith(".py"):
                # Deleted python files may break the tests that used them.
                code_diff.append(diff_obj.a_path)
            elif change_type in ["M", "R"] and diff_obj.b_path.endswith(".py"):
                if diff_obj.a_path != diff_obj.b_path:
                    # For renames, track the tests using both the old and the new name.
                    code_diff.extend([diff_obj.a_path, diff_obj.b_path])
                elif diff_is_docstring_only(repo, commit, diff_obj.b_path):
                    # Modifications limited to docstrings/comments are ignored.
                    print(f"Ignoring diff in {diff_obj.b_path} as it only concerns docstrings or comments.")
                else:
                    code_diff.append(diff_obj.a_path)

    return code_diff
275
+
276
+
277
def get_modified_python_files(diff_with_last_commit: bool = False) -> List[str]:
    """
    Return the list of python files that have been modified between:

    - the current head and the main branch if `diff_with_last_commit=False` (default)
    - the current head and its parent commit otherwise.

    Returns:
        `List[str]`: Python files that were added, deleted or renamed, plus modified files whose diff
        is not limited to docstrings or comments (see `diff_is_docstring_only`).
    """
    repo = Repo(PATH_TO_REPO)

    if diff_with_last_commit:
        print(f"main is at {repo.head.commit}")
        reference_commits = repo.head.commit.parents
        for commit in reference_commits:
            print(f"Parent commit: {commit}")
    else:
        # On GitHub Actions, the refs for main have to be fetched through the remote.
        upstream_main = repo.remotes.origin.refs.main
        print(f"main is at {upstream_main.commit}")
        print(f"Current head is at {repo.head.commit}")

        reference_commits = repo.merge_base(upstream_main, repo.head)
        for commit in reference_commits:
            print(f"Branching commit: {commit}")

    return get_diff(repo, repo.head.commit, reference_commits)
308
+
309
+
310
def get_diff_for_doctesting(repo: Repo, base_commit: str, commits: List[str]) -> List[str]:
    """
    Get the files whose doc examples changed between a base commit and one or several commits.

    Args:
        repo (`git.Repo`):
            The git repository to inspect.
        base_commit (`str`):
            The current commit (not the branching point!) to compare against.
        commits (`List[str]`):
            The commits (usually the branching point(s)) with which to compare the repo at
            `base_commit`.

    Returns:
        `List[str]`: Python and Markdown files that were added or renamed, plus modified files whose
        diff includes changed doc examples.
    """
    print("\n### DIFF ###\n")
    code_diff = []
    for commit in commits:
        for diff_obj in commit.diff(base_commit):
            # Only python and Markdown files matter here.
            # NOTE(review): `diff_obj.b_path` may be `None` for deletions — confirm "D" change types
            # never reach this loop.
            if not diff_obj.b_path.endswith((".py", ".md")):
                continue
            if diff_obj.change_type in ["A"]:
                # New python/md files are always relevant.
                code_diff.append(diff_obj.b_path)
            elif diff_obj.change_type in ["M", "R"]:
                if diff_obj.a_path != diff_obj.b_path:
                    # For renames, track the tests using both the old and the new name.
                    code_diff.extend([diff_obj.a_path, diff_obj.b_path])
                elif diff_contains_doc_examples(repo, commit, diff_obj.b_path):
                    code_diff.append(diff_obj.a_path)
                else:
                    print(f"Ignoring diff in {diff_obj.b_path} as it doesn't contain any doc example.")

    return code_diff
349
+
350
+
351
def get_all_doctest_files() -> List[str]:
    """
    Return the complete list of python and Markdown files on which we run doctest.

    At this moment, this is restricted to files under `src/` or `docs/source/en/` that are not
    listed in `utils/not_doctested.txt`.

    Returns:
        `List[str]`: The complete list of Python and Markdown files on which we run doctest.
    """
    candidates = [
        str(path.relative_to(PATH_TO_REPO))
        for pattern in ("**/*.py", "**/*.md")
        for path in PATH_TO_REPO.glob(pattern)
    ]

    # Only keep files under `src` or `docs/source/en/`, and drop init files.
    candidates = [f for f in candidates if f.startswith(("src/", "docs/source/en/"))]
    candidates = [f for f in candidates if not f.endswith(("__init__.py",))]

    # Files listed there are not doctested yet; this filter goes away once we reach 100% coverage.
    with open("utils/not_doctested.txt") as fp:
        not_doctested = {line.split(" ")[0] for line in fp.read().strip().split("\n")}

    return sorted(f for f in candidates if f not in not_doctested)
377
+
378
+
379
def get_new_doctest_files(repo, base_commit, branching_commit) -> List[str]:
    """
    Get the list of files that were removed from "utils/not_doctested.txt", between `base_commit`
    and `branching_commit`.

    Returns:
        `List[str]`: List of files that were removed from "utils/not_doctested.txt".
    """
    for diff_obj in branching_commit.diff(base_commit):
        # Only the "utils/not_doctested.txt" file is of interest here.
        if diff_obj.a_path != "utils/not_doctested.txt":
            continue
        folder = Path(repo.working_dir)
        # Content at the branching point...
        with checkout_commit(repo, branching_commit):
            old_content = (folder / "utils/not_doctested.txt").read_text(encoding="utf-8")
        # ... and current content.
        new_content = (folder / "utils/not_doctested.txt").read_text(encoding="utf-8")
        # Entries present before but gone now were removed from the list.
        old_entries = {line.split(" ")[0] for line in old_content.split("\n")}
        new_entries = {line.split(" ")[0] for line in new_content.split("\n")}
        return sorted(old_entries - new_entries)
    return []
404
+
405
+
406
def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]:
    """
    Return a list of python and Markdown files where doc examples have been modified between:

    - the current head and the main branch if `diff_with_last_commit=False` (default)
    - the current head and its parent commit otherwise.

    Returns:
        `List[str]`: The list of Python and Markdown files with a diff (files added or renamed are
        always returned, files modified are returned if the diff in the file is only in doctest
        examples).
    """
    repo = Repo(PATH_TO_REPO)

    test_files_to_run = []  # noqa
    # Entries removed from "utils/not_doctested.txt" become doctested and should run too.
    new_test_files = []
    if not diff_with_last_commit:
        upstream_main = repo.remotes.origin.refs.main
        print(f"main is at {upstream_main.commit}")
        print(f"Current head is at {repo.head.commit}")

        branching_commits = repo.merge_base(upstream_main, repo.head)
        for commit in branching_commits:
            print(f"Branching commit: {commit}")
        test_files_to_run = get_diff_for_doctesting(repo, repo.head.commit, branching_commits)
        new_test_files = get_new_doctest_files(repo, repo.head.commit, upstream_main.commit)
    else:
        print(f"main is at {repo.head.commit}")
        parent_commits = repo.head.commit.parents
        for commit in parent_commits:
            print(f"Parent commit: {commit}")
        test_files_to_run = get_diff_for_doctesting(repo, repo.head.commit, parent_commits)
        # Bug fix: `upstream_main` does not exist on this path (it was only defined in the other
        # branch, causing a NameError here); compare against the parent commits instead.
        for parent in parent_commits:
            new_test_files.extend(get_new_doctest_files(repo, repo.head.commit, parent))

    all_test_files_to_run = get_all_doctest_files()

    test_files_to_run = list(set(test_files_to_run + new_test_files))

    # Do not run slow doctest tests on CircleCI.
    with open("utils/slow_documentation_tests.txt") as fp:
        slow_documentation_tests = set(fp.read().strip().split("\n"))
    test_files_to_run = [
        x for x in test_files_to_run if x in all_test_files_to_run and x not in slow_documentation_tests
    ]

    # Make sure we did not end up with a test file that was removed.
    return sorted(f for f in test_files_to_run if (PATH_TO_REPO / f).exists())
453
+
454
+
455
# (?:^|\n) -> Non-capturing group for the beginning of the doc or a new line.
# \s*from\s+(\.+\S+)\s+import\s+([^\n]+) -> Line only contains from .xxx import yyy and we catch .xxx and yyy
# (?=\n) -> Look-ahead to a new line. We can't just put \n here or using find_all on this re will only catch every
# other import.
_re_single_line_relative_imports = re.compile(r"(?:^|\n)\s*from\s+(\.+\S+)\s+import\s+([^\n]+)(?=\n)")
# \s*from\s+(\.+\S+)\s+import\s+\(([^\)]+)\) -> Line continues with from .xxx import (yyy) and we catch .xxx and yyy
# yyy will take multiple lines otherwise there wouldn't be parenthesis.
_re_multi_line_relative_imports = re.compile(r"(?:^|\n)\s*from\s+(\.+\S+)\s+import\s+\(([^\)]+)\)")
# \s*from\s+diffusers(\S*)\s+import\s+([^\n]+) -> Line only contains from diffusers.xxx import yyy and we catch
# .xxx and yyy. (Fixed comments: they previously said "transformers" but these patterns match `diffusers`.)
# (?=\n) -> Look-ahead to a new line, for the same reason as above.
_re_single_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+diffusers(\S*)\s+import\s+([^\n]+)(?=\n)")
# \s*from\s+diffusers(\S*)\s+import\s+\(([^\)]+)\) -> Line continues with from diffusers.xxx import (yyy) and we
# catch .xxx and yyy. yyy will take multiple lines otherwise there wouldn't be parenthesis.
_re_multi_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+diffusers(\S*)\s+import\s+\(([^\)]+)\)")
474
+
475
+
476
def extract_imports(module_fname: str, cache: Dict[str, List[str]] = None) -> List[str]:
    """
    Get the imports a given module makes.

    Args:
        module_fname (`str`):
            The module file (relative to the repo root) whose imports we want.
        cache (Dictionary `str` to `List[str]`, *optional*):
            Previously computed results, to speed up repeated calls on the same file.

    Returns:
        `List[str]`: Pairs of (imported module filename, imported object names); a submodule import
        that resolves to a folder yields that folder's init file.
    """
    if cache is not None and module_fname in cache:
        return cache[module_fname]

    with open(PATH_TO_REPO / module_fname, "r", encoding="utf-8") as f:
        content = f.read()

    # Strip docstrings so imports inside code examples are not picked up. Autoformatting stays off
    # so the escaped quotes survive (real triple quotes would break this function on this file).
    # fmt: off
    content = "".join(content.split('\"\"\"')[::2])
    # fmt: on

    module_parts = str(module_fname).split(os.path.sep)
    imported_modules = []

    # Relative imports first: single-line, then multi-line (parenthesized) forms.
    relative_imports = [
        (mod, imp)
        for mod, imp in _re_single_line_relative_imports.findall(content)
        if "# tests_ignore" not in imp and imp.strip() != "("
    ]
    relative_imports += [
        (mod, imp)
        for mod, imp in _re_multi_line_relative_imports.findall(content)
        if "# tests_ignore" not in imp
    ]

    # The number of leading dots tells us how many parts of the current module path to drop.
    for module, imports in relative_imports:
        level = len(module) - len(module.lstrip("."))
        module = module[level:]
        dep_parts = module_parts[: len(module_parts) - level]
        if module:
            dep_parts = dep_parts + module.split(".")
        imported_modules.append((os.path.sep.join(dep_parts), [imp.strip() for imp in imports.split(",")]))

    # Then direct imports (from diffusers...), same two forms.
    direct_imports = [
        (mod, imp)
        for mod, imp in _re_single_line_direct_imports.findall(content)
        if "# tests_ignore" not in imp and imp.strip() != "("
    ]
    direct_imports += [
        (mod, imp)
        for mod, imp in _re_multi_line_direct_imports.findall(content)
        if "# tests_ignore" not in imp
    ]

    # Turn those into paths relative to the repo root (dropping the repo name itself).
    for module, imports in direct_imports:
        dep_parts = ["src", "diffusers"] + module.split(".")[1:]
        imported_modules.append((os.path.sep.join(dep_parts), [imp.strip() for imp in imports.split(",")]))

    result = []
    # Double check we get proper modules (either a python file or a folder with an init).
    for module_file, imports in imported_modules:
        if (PATH_TO_REPO / f"{module_file}.py").is_file():
            module_file = f"{module_file}.py"
        elif (PATH_TO_REPO / module_file).is_dir() and (PATH_TO_REPO / module_file / "__init__.py").is_file():
            module_file = os.path.sep.join([module_file, "__init__.py"])
        imports = [imp for imp in imports if len(imp) > 0 and re.match("^[A-Za-z0-9_]*$", imp)]
        if len(imports) > 0:
            result.append((module_file, imports))

    if cache is not None:
        cache[module_fname] = result

    return result
558
+
559
+
560
def get_module_dependencies(module_fname: str, cache: Dict[str, List[str]] = None) -> List[str]:
    """
    Refine the result of `extract_imports` into a proper list of module filenames: when a file
    imports from a subfolder's init (e.g. `from utils import Foo, Bar`), traverse that init to find
    the submodule(s) the imported objects actually come from (e.g. `utils/foo.py`, `utils/bar.py`).

    Warning: this presupposes that all intermediate inits are properly built (with imports from the
    respective submodules) and works better when objects live in submodules rather than in the init
    itself (otherwise the init is added, and inits usually have a lot of dependencies).

    Args:
        module_fname (`str`):
            The module file (relative to the repo root) whose dependencies we want.
        cache (Dictionary `str` to `List[str]`, *optional*):
            Previously computed results, to speed up repeated calls.

    Returns:
        `List[str]`: The module filenames imported by `module_fname`, with submodule imports refined.
    """
    dependencies = []
    to_process = extract_imports(module_fname, cache=cache)
    # Keep traversing any init we encounter, accumulating new (module, imports) pairs to visit.
    while to_process:
        next_round = []
        for module, imports in to_process:
            if not module.endswith("__init__.py"):
                dependencies.append(module)
                continue
            # We usually don't actually import from an init: follow its own imports to find where
            # the objects really come from.
            for sub_module, sub_imports in extract_imports(module, cache=cache):
                matching = [i for i in sub_imports if i in imports]
                if matching:
                    if sub_module not in dependencies:
                        next_round.append((sub_module, matching))
                    imports = [i for i in imports if i not in sub_imports]
            if imports:
                # Leftover objects may live in submodule files next to the init.
                init_folder = PATH_TO_REPO / module.replace("__init__.py", "")
                dependencies.extend(
                    os.path.join(module.replace("__init__.py", ""), f"{i}.py")
                    for i in imports
                    if (init_folder / f"{i}.py").is_file()
                )
                imports = [i for i in imports if not (init_folder / f"{i}.py").is_file()]
                if imports:
                    # Anything still left is fully defined in the init, so keep it as a dependency.
                    dependencies.append(module)
        to_process = next_round

    return dependencies
618
+
619
+
620
def create_reverse_dependency_tree() -> List[Tuple[str, str]]:
    """
    Create the list of all edges (a, b) meaning that modifying `a` impacts `b`, with `a` going over
    all module and test files.
    """
    cache = {}
    modules = [
        str(mod.relative_to(PATH_TO_REPO))
        for mod in list(PATH_TO_DIFFUSERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py"))
    ]
    # Deduplicate edges with a set comprehension before returning them as a list.
    return list({(dep, mod) for mod in modules for dep in get_module_dependencies(mod, cache=cache)})
630
+
631
+
632
def get_tree_starting_at(module: str, edges: List[Tuple[str, str]]) -> List[Union[str, List[str]]]:
    """
    Return the tree starting at a given module following all edges.

    Args:
        module (`str`): The module that will be the root of the subtree we want.
        edges (`List[Tuple[str, str]]`): The list of all edges of the tree.

    Returns:
        `List[Union[str, List[str]]]`: The tree in the format [module, [edges starting at module],
        [edges starting at the previous level], ...]
    """
    seen = [module]
    tree = [module]
    # Edges leaving the root, ignoring self-loops and init files.
    frontier = [e for e in edges if e[0] == module and e[1] != module and "__init__.py" not in e[1]]
    while frontier:
        tree.append(frontier)
        # The targets of this level become the sources of the next one.
        level_targets = list({e[1] for e in frontier})
        seen.extend(level_targets)
        frontier = [
            e for e in edges if e[0] in level_targets and e[1] not in seen and "__init__.py" not in e[1]
        ]

    return tree
658
+
659
+
660
def print_tree_deps_of(module, all_edges=None):
    """
    Print the tree of modules depending on a given module.

    Args:
        module (`str`): The module that will be the root of the subtree we want.
        all_edges (`List[Tuple[str, str]]`, *optional*):
            The list of all edges of the tree. Will be set to `create_reverse_dependency_tree()` if
            not passed.
    """
    if all_edges is None:
        all_edges = create_reverse_dependency_tree()
    tree = get_tree_starting_at(module, all_edges)

    # Build (line_to_print, module) pairs; keeping the module tells us where to insert new lines.
    lines = [(tree[0], tree[0])]
    for depth in range(1, len(tree)):
        level_edges = tree[depth]
        starts = {edge[0] for edge in level_edges}

        for start in starts:
            ends = {edge[1] for edge in level_edges if edge[0] == start}
            # Insert all those edges right below the line showing `start`.
            pos = 0
            while lines[pos][1] != start:
                pos += 1
            inserted = [(" " * (2 * depth) + end, end) for end in ends]
            lines = lines[: pos + 1] + inserted + lines[pos + 1 :]

    for printed, _ in lines:
        # The module refs were only there to help build the lines.
        print(printed)
691
+
692
+
693
def init_test_examples_dependencies() -> Tuple[Dict[str, List[str]], List[str]]:
    """
    The test examples do not import from the examples (which are just scripts, not modules), so the
    dependency map needs special initialization: this function links each example file to the example
    test file of its framework.

    Returns:
        `Tuple[Dict[str, List[str]], List[str]]`: A tuple with two elements: the initialized
        dependency map (test example file -> list of example files potentially tested by it) and the
        list of all example files (to avoid recomputing it later).
    """
    test_example_deps = {}
    all_examples = []
    for framework in ["flax", "pytorch", "tensorflow"]:
        framework_dir = PATH_TO_EXAMPLES / framework
        test_files = list(framework_dir.glob("test_*.py"))
        all_examples.extend(test_files)
        # Files at the root of examples/<framework> are utils or example test files, not proper
        # examples.
        examples = [f for f in framework_dir.glob("**/*.py") if f.parent != framework_dir]
        all_examples.extend(examples)
        for test_file in test_files:
            with open(test_file, "r", encoding="utf-8") as f:
                content = f.read()
            rel_test_file = str(test_file.relative_to(PATH_TO_REPO))
            # Map all examples mentioned in the test file to it, plus the test file to itself.
            test_example_deps[rel_test_file] = [
                str(e.relative_to(PATH_TO_REPO)) for e in examples if e.name in content
            ]
            test_example_deps[rel_test_file].append(rel_test_file)
    return test_example_deps, all_examples
727
+
728
+
729
def create_reverse_dependency_map() -> Dict[str, List[str]]:
    """
    Create the dependency map from module/test filename to the list of modules/tests that depend on
    it recursively.

    Returns:
        `Dict[str, List[str]]`: The reverse dependency map as a dictionary mapping filenames to all
        the filenames depending on it recursively. This way the tests impacted by a change in file A
        are the test files in the list corresponding to key A in this result.
    """
    cache = {}
    # Start from the example deps init.
    example_deps, examples = init_test_examples_dependencies()
    # All modules, all tests and all examples take part in the map.
    all_modules = list(PATH_TO_DIFFUSERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py")) + examples
    all_modules = [str(mod.relative_to(PATH_TO_REPO)) for mod in all_modules]
    # Compute the direct dependencies of all modules.
    direct_deps = {m: get_module_dependencies(m, cache=cache) for m in all_modules}
    direct_deps.update(example_deps)

    # Propagate dependencies until a fixed point is reached.
    something_changed = True
    while something_changed:
        something_changed = False
        for m in all_modules:
            for d in direct_deps[m]:
                # Stop recursing at inits: we would otherwise end up in the main init and pull in
                # every file it imports.
                if d.endswith("__init__.py"):
                    continue
                if d not in direct_deps:
                    raise ValueError(f"KeyError:{d}. From {m}")
                new_deps = set(direct_deps[d]) - set(direct_deps[m])
                if len(new_deps) > 0:
                    direct_deps[m].extend(list(new_deps))
                    something_changed = True

    # Finally we can build the reverse map.
    reverse_map = collections.defaultdict(list)
    for m in all_modules:
        for d in direct_deps[m]:
            reverse_map[d].append(m)

    # For inits, we don't do the reverse deps but the direct deps: if modifying an init, we want to
    # make sure we test all the modules impacted by that init.
    # (Fix: the original rebound the name `direct_deps` here, clobbering the dict above with a
    # per-init list; renamed to `init_deps` for clarity — behavior is unchanged.)
    for init_module in [f for f in all_modules if f.endswith("__init__.py")]:
        init_deps = get_module_dependencies(init_module, cache=cache)
        deps = sum([reverse_map[d] for d in init_deps if not d.endswith("__init__.py")], init_deps)
        reverse_map[init_module] = list(set(deps) - {init_module})

    return reverse_map
779
+
780
+
781
def create_module_to_test_map(reverse_map: Optional[Dict[str, List[str]]] = None) -> Dict[str, List[str]]:
    """
    Extract the test files from the reverse dependency map.

    Args:
        reverse_map (`Dict[str, List[str]]`, *optional*):
            The reverse dependency map as created by `create_reverse_dependency_map`. Will default to the result of
            that function if not provided.

    Returns:
        `Dict[str, List[str]]`: A dictionary that maps each file to the tests to execute if that file was modified.
    """
    if reverse_map is None:
        reverse_map = create_reverse_dependency_map()

    # Utility that tells us if a given file is a test (taking test examples into account).
    def is_test(fname):
        if fname.startswith("tests"):
            return True
        # Example tests live under `examples/` and their filename starts with `test`.
        if fname.startswith("examples") and fname.split(os.path.sep)[-1].startswith("test"):
            return True
        return False

    # Keep only the test files among each module's dependents.
    test_map = {module: [f for f in deps if is_test(f)] for module, deps in reverse_map.items()}

    return test_map
810
+
811
+
812
def check_imports_all_exist():
    """
    Sanity check (not used by the test fetcher itself, kept as a potential quality check): verify that every
    dependency resolved for every module points to a file that actually exists, printing one line per missing
    dependency.
    """
    cache = {}
    modules = list(PATH_TO_DIFFUSERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py"))
    modules = [str(mod.relative_to(PATH_TO_REPO)) for mod in modules]
    dependencies = {mod: get_module_dependencies(mod, cache=cache) for mod in modules}

    for module, deps in dependencies.items():
        for dep in deps:
            if not (PATH_TO_REPO / dep).is_file():
                print(f"{module} has dependency on {dep} which does not exist.")
826
+
827
+
828
def _print_list(l) -> str:
    """Format every element of `l` on its own line, each prefixed with "- "."""
    return "\n".join(f"- {element}" for element in l)
833
+
834
+
835
def update_test_map_with_core_pipelines(json_output_file: str):
    """
    Post-process the saved test map: give the core pipelines (`IMPORTANT_PIPELINES`) their own `core_pipelines`
    test group, remove them from the fetched `pipelines` group, and write the map back to `json_output_file`.
    """
    print(f"\n### ADD CORE PIPELINE TESTS ###\n{_print_list(IMPORTANT_PIPELINES)}")
    with open(json_output_file, "rb") as fp:
        test_map = json.load(fp)

    # Add core pipelines as their own test group
    test_map["core_pipelines"] = " ".join(
        sorted([str(PATH_TO_TESTS / f"pipelines/{pipe}") for pipe in IMPORTANT_PIPELINES])
    )

    # If there are no existing pipeline tests, save the map and stop here: without this early return, the
    # `test_map.pop("pipelines")` below would raise a KeyError right after dumping.
    if "pipelines" not in test_map:
        with open(json_output_file, "w", encoding="UTF-8") as fp:
            json.dump(test_map, fp, ensure_ascii=False)
        return

    pipeline_tests = test_map.pop("pipelines")
    pipeline_tests = pipeline_tests.split(" ")

    # Remove core pipeline tests from the fetched pipeline tests.
    # NOTE(review): `Path(pipe).parts[2]` assumes entries look like `tests/pipelines/<name>[/...]` — confirm.
    updated_pipeline_tests = []
    for pipe in pipeline_tests:
        if pipe == "tests/pipelines" or Path(pipe).parts[2] in IMPORTANT_PIPELINES:
            continue
        updated_pipeline_tests.append(pipe)

    if len(updated_pipeline_tests) > 0:
        test_map["pipelines"] = " ".join(sorted(updated_pipeline_tests))

    with open(json_output_file, "w", encoding="UTF-8") as fp:
        json.dump(test_map, fp, ensure_ascii=False)
865
+
866
+
867
def create_json_map(test_files_to_run: List[str], json_output_file: Optional[str] = None):
    """
    Group a list of tests to run by category and store the result as a JSON map, used to split the slow tests for
    parallelism.

    Args:
        test_files_to_run (`List[str]`): The list of tests to run.
        json_output_file (`str`): The path where to store the built json map.
    """
    if json_output_file is None:
        return

    grouped = {}
    for test_file in test_files_to_run:
        # `test_file` is a path to a test folder/file, starting with `tests/`. For example,
        # - `tests/models/bert/test_modeling_bert.py` or `tests/models/bert`
        # - `tests/trainer/test_trainer.py` or `tests/trainer`
        # - `tests/test_modeling_common.py`
        parts = test_file.split(os.path.sep)
        if parts[1] in MODULES_TO_IGNORE:
            continue

        # Folders under `tests` (or python files nested inside them) are keyed by their category (tokenization,
        # `pipeline`, etc.); common test files sitting directly under `tests/` go into the `common` group.
        if len(parts) > 2 or not test_file.endswith(".py"):
            key = parts[1]
        else:
            key = "common"
        grouped.setdefault(key, []).append(test_file)

    # Sort the categories and the tests within each category for deterministic output.
    test_map = {category: " ".join(sorted(files)) for category, files in sorted(grouped.items())}

    with open(json_output_file, "w", encoding="UTF-8") as fp:
        json.dump(test_map, fp, ensure_ascii=False)
907
+
908
+
909
def infer_tests_to_run(
    output_file: str,
    diff_with_last_commit: bool = False,
    json_output_file: Optional[str] = None,
):
    """
    The main function called by the test fetcher. Determines the tests to run from the diff.

    Args:
        output_file (`str`):
            The path where to store the summary of the test fetcher analysis. One other file is stored in the same
            folder:

            - examples_test_list.txt: The list of examples tests to run.

        diff_with_last_commit (`bool`, *optional*, defaults to `False`):
            Whether to analyze the diff with the last commit (for use on the main branch after a PR is merged) or with
            the branching point from main (for use on each PR).
        json_output_file (`str`, *optional*):
            The path where to store the json file mapping categories of tests to tests to run (used for parallelism or
            the slow tests).
    """
    modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)
    print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}")
    # Create the map that will give us all impacted modules.
    reverse_map = create_reverse_dependency_map()
    impacted_files = modified_files.copy()
    for f in modified_files:
        if f in reverse_map:
            impacted_files.extend(reverse_map[f])

    # Remove duplicates
    impacted_files = sorted(set(impacted_files))
    print(f"\n### IMPACTED FILES ###\n{_print_list(impacted_files)}")

    # Grab the corresponding test files:
    if any(x in modified_files for x in ["setup.py"]):
        test_files_to_run = ["tests", "examples"]

    # in order to trigger pipeline tests even if no code change at all
    # NOTE(review): this `if/else` is independent from the `setup.py` check above, so a `setup.py` change without a
    # tiny_model_summary.json change falls into the `else` branch and overwrites `test_files_to_run` — confirm this
    # is intended.
    if "tests/utils/tiny_model_summary.json" in modified_files:
        test_files_to_run = ["tests"]
    else:
        # All modified tests need to be run.
        test_files_to_run = [
            f for f in modified_files if f.startswith("tests") and f.split(os.path.sep)[-1].startswith("test")
        ]
        # Then we grab the corresponding test files.
        test_map = create_module_to_test_map(reverse_map=reverse_map)
        for f in modified_files:
            if f in test_map:
                test_files_to_run.extend(test_map[f])
        test_files_to_run = sorted(set(test_files_to_run))
        # Make sure we did not end up with a test file that was removed
        test_files_to_run = [f for f in test_files_to_run if (PATH_TO_REPO / f).exists()]

    # Example tests are written to a separate list and stripped from the main one.
    examples_tests_to_run = [f for f in test_files_to_run if f.startswith("examples")]
    test_files_to_run = [f for f in test_files_to_run if not f.startswith("examples")]
    print(f"\n### TEST TO RUN ###\n{_print_list(test_files_to_run)}")
    if len(test_files_to_run) > 0:
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(" ".join(test_files_to_run))

    # Create a map that maps test categories to test files, i.e. `models/bert` -> [...test_modeling_bert.py, ...]

    # Get all test directories (and some common test files) under `tests` and `tests/models` if `test_files_to_run`
    # contains `tests` (i.e. when `setup.py` is changed).
    if "tests" in test_files_to_run:
        test_files_to_run = get_all_tests()

    create_json_map(test_files_to_run, json_output_file)

    print(f"\n### EXAMPLES TEST TO RUN ###\n{_print_list(examples_tests_to_run)}")
    if len(examples_tests_to_run) > 0:
        # We use `all` in the case `commit_flags["test_all"]` as well as in `create_circleci_config.py` for processing
        if examples_tests_to_run == ["examples"]:
            examples_tests_to_run = ["all"]
        example_file = Path(output_file).parent / "examples_test_list.txt"
        with open(example_file, "w", encoding="utf-8") as f:
            f.write(" ".join(examples_tests_to_run))
997
+
998
+
999
def filter_tests(output_file: str, filters: List[str]):
    """
    Read the tests fetcher output file and drop every test living in one of the given folders, rewriting the file
    in place.

    Args:
        output_file (`str` or `os.PathLike`): The path to the output file of the tests fetcher.
        filters (`List[str]`): A list of folders to filter.
    """
    if not os.path.isfile(output_file):
        print("No test file found.")
        return
    with open(output_file, "r", encoding="utf-8") as f:
        test_files = f.read().split(" ")

    if len(test_files) == 0 or test_files == [""]:
        print("No tests to filter.")
        return

    if test_files == ["tests"]:
        # The whole suite was selected: expand to the top-level test folders, minus the filtered ones.
        excluded = ["__init__.py"] + filters
        kept = [os.path.join("tests", entry) for entry in os.listdir("tests") if entry not in excluded]
    else:
        # Keep only the tests whose first folder under `tests/` is not filtered.
        kept = [t for t in test_files if t.split(os.path.sep)[1] not in filters]

    with open(output_file, "w", encoding="utf-8") as f:
        f.write(" ".join(kept))
1024
+
1025
+
1026
def parse_commit_message(commit_message: str) -> Dict[str, bool]:
    """
    Parses the commit message to detect if a command is there to skip, force all or part of the CI.

    Args:
        commit_message (`str`): The commit message of the current commit.

    Returns:
        `Dict[str, bool]`: A dictionary of strings to bools with the keys `"skip"`, `"no_filter"` and `"test_all"`.
    """
    if commit_message is None:
        return {"skip": False, "no_filter": False, "test_all": False}

    # Commands are given between square brackets, e.g. `[skip ci]` or `[test all]`.
    command_search = re.search(r"\[([^\]]*)\]", commit_message)
    if command_search is not None:
        command = command_search.groups()[0]
        # Normalize separators so `[test-all]`, `[test_all]` and `[test all]` are all recognized.
        command = command.lower().replace("-", " ").replace("_", " ")
        skip = command in ["ci skip", "skip ci", "circleci skip", "skip circleci"]
        no_filter = set(command.split(" ")) == {"no", "filter"}
        test_all = set(command.split(" ")) == {"test", "all"}
        return {"skip": skip, "no_filter": no_filter, "test_all": test_all}
    else:
        return {"skip": False, "no_filter": False, "test_all": False}
1050
+
1051
+
1052
if __name__ == "__main__":
    # Command-line entry point for the test fetcher.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--output_file", type=str, default="test_list.txt", help="Where to store the list of tests to run"
    )
    arg_parser.add_argument(
        "--json_output_file",
        type=str,
        default="test_map.json",
        help="Where to store the tests to run in a dictionary format mapping test categories to test files",
    )
    arg_parser.add_argument(
        "--diff_with_last_commit",
        action="store_true",
        help="To fetch the tests between the current commit and the last commit",
    )
    arg_parser.add_argument(
        "--filter_tests",
        action="store_true",
        help="Will filter the pipeline/repo utils tests outside of the generated list of tests.",
    )
    arg_parser.add_argument(
        "--print_dependencies_of",
        type=str,
        help="Will only print the tree of modules depending on the file passed.",
        default=None,
    )
    arg_parser.add_argument(
        "--commit_message",
        type=str,
        help="The commit message (which could contain a command to force all tests or skip the CI).",
        default=None,
    )
    cli_args = arg_parser.parse_args()

    if cli_args.print_dependencies_of is not None:
        # Debug mode: only display the dependency tree of a single file.
        print_tree_deps_of(cli_args.print_dependencies_of)
    else:
        repo = Repo(PATH_TO_REPO)
        # NOTE(review): the `--commit_message` argument is parsed above but the HEAD commit message is what gets
        # used here — confirm intended.
        commit_flags = parse_commit_message(repo.head.commit.message)
        if commit_flags["skip"]:
            print("Force-skipping the CI")
            quit()
        if commit_flags["no_filter"]:
            print("Running all tests fetched without filtering.")
        if commit_flags["test_all"]:
            print("Force-launching all tests")

        # On the main branch, always diff against the last commit.
        diff_with_last_commit = cli_args.diff_with_last_commit
        if not diff_with_last_commit and not repo.head.is_detached and repo.head.ref == repo.refs.main:
            print("main branch detected, fetching tests against last commit.")
            diff_with_last_commit = True

        if not commit_flags["test_all"]:
            try:
                infer_tests_to_run(
                    cli_args.output_file,
                    diff_with_last_commit=diff_with_last_commit,
                    json_output_file=cli_args.json_output_file,
                )
                filter_tests(cli_args.output_file, ["repo_utils"])
                update_test_map_with_core_pipelines(json_output_file=cli_args.json_output_file)
            except Exception as e:
                # Any failure during the analysis falls back to running the full test suite.
                print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.")
                commit_flags["test_all"] = True

        if commit_flags["test_all"]:
            with open(cli_args.output_file, "w", encoding="utf-8") as f:
                f.write("tests")
            example_file = Path(cli_args.output_file).parent / "examples_test_list.txt"
            with open(example_file, "w", encoding="utf-8") as f:
                f.write("all")

            test_files_to_run = get_all_tests()
            create_json_map(test_files_to_run, cli_args.json_output_file)
            update_test_map_with_core_pipelines(json_output_file=cli_args.json_output_file)