RioShiina commited on
Commit
a6882ee
Β·
verified Β·
1 Parent(s): 03e71e5

Upload folder using huggingface_hub

Browse files
comfy_integration/setup.py CHANGED
@@ -20,7 +20,7 @@ def initialize_comfyui():
20
 
21
  print("--- Cloning ComfyUI Repository ---")
22
  if not os.path.exists(COMFYUI_TEMP_DIR):
23
- os.system(f"git clone https://github.com/comfyanonymous/ComfyUI {COMFYUI_TEMP_DIR}")
24
  print("βœ… ComfyUI repository cloned.")
25
  else:
26
  print("βœ… ComfyUI repository already exists.")
 
20
 
21
  print("--- Cloning ComfyUI Repository ---")
22
  if not os.path.exists(COMFYUI_TEMP_DIR):
23
+ os.system(f"git clone https://github.com/comfy-Org/ComfyUI {COMFYUI_TEMP_DIR}")
24
  print("βœ… ComfyUI repository cloned.")
25
  else:
26
  print("βœ… ComfyUI repository already exists.")
scripts/build_sage_attention.py CHANGED
@@ -1,66 +1,99 @@
1
  import os
2
  import subprocess
3
  import sys
 
4
 
5
- def run_command(command, env=None):
6
- """
7
- Runs a command with a specified environment, prints its output,
8
- and raises an exception on failure.
9
- """
10
  print(f"πŸš€ Running command: {' '.join(command)}")
11
  result = subprocess.run(
12
  command,
 
13
  env=env,
14
  stdout=subprocess.PIPE,
15
  stderr=subprocess.STDOUT,
16
- text=True,
17
- encoding='utf-8',
18
- errors='replace'
19
  )
20
-
21
- if result.stdout:
22
- print("--- Pip Output ---")
23
- print(result.stdout.strip())
24
- print("------------------")
25
-
26
  if result.returncode != 0:
 
27
  raise subprocess.CalledProcessError(result.returncode, command)
28
 
29
- def install_sage_attention():
30
- """
31
- Installs the sageattention package from PyPI using pip, ensuring the
32
- correct CUDA architecture is set for any potential on-the-fly compilation.
33
- """
34
- print("--- [SageAttention Install] Starting installation using pip ---")
35
 
36
- build_env = os.environ.copy()
37
- build_env["TORCH_CUDA_ARCH_LIST"] = "9.0"
38
- print(f"πŸ”§ Setting build environment variable: TORCH_CUDA_ARCH_LIST='{build_env['TORCH_CUDA_ARCH_LIST']}'")
39
 
40
- install_command = [
41
- sys.executable,
42
- "-m",
43
- "pip",
44
- "install",
45
- "sageattention==2.2.0",
46
- "--no-build-isolation",
47
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
 
 
49
  try:
50
- run_command(install_command, env=build_env)
51
- print("πŸŽ‰ SageAttention installed successfully via pip!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  except subprocess.CalledProcessError as e:
53
- print(f"❌ SageAttention installation via pip failed (Exit Code: {e.returncode}).")
54
- print(" The application will continue with the default attention mechanism.")
55
- raise e
 
56
  except Exception as e:
57
- print(f"❌ An unexpected error occurred during pip installation: {e}")
58
- raise e
59
-
60
 
61
  if __name__ == "__main__":
62
- try:
63
- install_sage_attention()
64
- except Exception:
65
- print("\nInstallation script finished with an error.")
66
- sys.exit(1)
 
1
  import os
2
  import subprocess
3
  import sys
4
+ import textwrap
5
 
6
+ REPO_URL = "https://github.com/thu-ml/SageAttention.git"
7
+ REPO_DIR = "SageAttention"
8
+
9
+ def run_command(command, cwd=None, env=None):
 
10
  print(f"πŸš€ Running command: {' '.join(command)}")
11
  result = subprocess.run(
12
  command,
13
+ cwd=cwd,
14
  env=env,
15
  stdout=subprocess.PIPE,
16
  stderr=subprocess.STDOUT,
17
+ text=True
 
 
18
  )
19
+
 
 
 
 
 
20
  if result.returncode != 0:
21
+ print(result.stdout)
22
  raise subprocess.CalledProcessError(result.returncode, command)
23
 
24
+ def patch_setup_py(setup_py_path):
25
+ print(f"--- [SageAttention Build] Applying patches to {setup_py_path} ---")
 
 
 
 
26
 
27
+ with open(setup_py_path, 'r', encoding='utf-8') as f:
28
+ content = f.read()
 
29
 
30
+ original_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]'
31
+ modified_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-std=c++17", "-DENABLE_BF16"]'
32
+
33
+ if original_cxx_flags in content:
34
+ content = content.replace(original_cxx_flags, modified_cxx_flags)
35
+ print("πŸ”§ Patch 1/1: Removed '-fopenmp' and '-lgomp' from CXX_FLAGS.")
36
+ else:
37
+ print("⚠️ Patch 1/1: CXX_FLAGS line not found as expected. It might have been changed upstream. Skipping.")
38
+
39
+ with open(setup_py_path, 'w', encoding='utf-8') as f:
40
+ f.write(content)
41
+
42
+ print("βœ… Patches applied successfully.")
43
+
44
+
45
+ def install_sage_attention():
46
+ print("--- [SageAttention Build] Checking environment ---")
47
+
48
+ if os.path.isdir(REPO_DIR):
49
+ print(f"βœ… Directory '{REPO_DIR}' already exists, assuming SageAttention is installed. Skipping build.")
50
+ return
51
 
52
+ print(f"⏳ Directory '{REPO_DIR}' not found. Starting a fresh installation of SageAttention.")
53
+
54
  try:
55
+ print(f"--- [SageAttention Build] Step 1/3: Cloning repository ---")
56
+ run_command(["git", "clone", REPO_URL])
57
+ print("βœ… Repository cloned successfully.")
58
+
59
+ print(f"--- [SageAttention Build] Step 2/3: Patching setup.py ---")
60
+ setup_py_path = os.path.join(REPO_DIR, "setup.py")
61
+ patch_setup_py(setup_py_path)
62
+
63
+ print(f"--- [SageAttention Build] Step 3/3: Compiling and installing ---")
64
+
65
+ build_env = os.environ.copy()
66
+ build_env.update({
67
+ "TORCH_CUDA_ARCH_LIST": "9.0",
68
+ "EXT_PARALLEL": "4",
69
+ "NVCC_APPEND_FLAGS": "--threads 8",
70
+ "MAX_JOBS": "32"
71
+ })
72
+ print("πŸ”§ Setting build environment variables:")
73
+ print(f" - TORCH_CUDA_ARCH_LIST='{build_env['TORCH_CUDA_ARCH_LIST']}'")
74
+ print(f" - EXT_PARALLEL={build_env['EXT_PARALLEL']}")
75
+ print(f" - NVCC_APPEND_FLAGS='{build_env['NVCC_APPEND_FLAGS']}'")
76
+ print(f" - MAX_JOBS={build_env['MAX_JOBS']}")
77
+
78
+ install_command = [sys.executable, "setup.py", "install"]
79
+
80
+ run_command(install_command, cwd=REPO_DIR, env=build_env)
81
+
82
+ print("πŸŽ‰ SageAttention compiled and installed successfully! ---")
83
+
84
+ except FileNotFoundError:
85
+ print("❌ ERROR: 'git' command not found. Please ensure Git is installed in your environment.")
86
+ sys.exit(1)
87
  except subprocess.CalledProcessError as e:
88
+ print(f"❌ Command failed with return code: {e.returncode}")
89
+ print(f"❌ Command: {' '.join(e.cmd)}")
90
+ print("❌ SageAttention installation failed. Please check the logs above for details.")
91
+ sys.exit(1)
92
  except Exception as e:
93
+ print(f"❌ An unknown error occurred: {e}")
94
+ sys.exit(1)
 
95
 
96
  if __name__ == "__main__":
97
+ if os.path.isdir(REPO_DIR):
98
+ print(f"Note: To force a rebuild, please delete the '{REPO_DIR}' directory first.")
99
+ install_sage_attention()