Spaces:
Running
on
Zero
Running
on
Zero
import os
import shutil
import subprocess
import sys
import textwrap
# Git remote that the SageAttention sources are cloned from.
REPO_URL = "https://github.com/thu-ml/SageAttention.git"

# Local checkout directory. Its presence is also what the installer uses to
# decide that a build has already been done.
REPO_DIR = "SageAttention"
def run_command(command, cwd=None, env=None):
    """Run *command* and raise if it exits with a non-zero status.

    The child's stdout and stderr are merged and captured; nothing is echoed
    while the command runs. On failure the captured text is printed and also
    attached to the raised exception so callers can inspect it.

    Args:
        command: Program and arguments as a list of strings (no shell).
        cwd: Working directory for the child, or None to use the current one.
        env: Environment mapping for the child, or None to inherit.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero. Its
            ``output`` attribute holds the merged stdout/stderr text.
        FileNotFoundError: Propagated from ``subprocess.run`` when the
            executable cannot be found.
    """
    print(f"π Running command: {' '.join(command)}")
    result = subprocess.run(
        command,
        cwd=cwd,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave stderr into the same stream
        text=True,
    )
    if result.returncode != 0:
        print(result.stdout)
        # Carry the captured log on the exception instead of dropping it.
        raise subprocess.CalledProcessError(
            result.returncode, command, output=result.stdout
        )
def patch_setup_py(setup_py_path):
    """Strip the OpenMP flags from the CXX_FLAGS line of a setup.py file.

    The file is searched for one exact CXX_FLAGS assignment. When it is
    present, it is replaced by the same list without "-fopenmp" and "-lgomp"
    and the file is rewritten. When it is absent, a warning is printed and the
    file is left untouched (no rewrite, no success message).

    Args:
        setup_py_path: Path of the setup.py file to patch in place.
    """
    print(f"--- [SageAttention Build] Applying patches to {setup_py_path} ---")

    original_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]'
    modified_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-std=c++17", "-DENABLE_BF16"]'

    with open(setup_py_path, 'r', encoding='utf-8') as f:
        content = f.read()

    if original_cxx_flags not in content:
        # Nothing to change: do not rewrite the file or report success.
        print("β οΈ Patch 1/1: CXX_FLAGS line not found as expected. It might have been changed upstream. Skipping.")
        return

    print("π§ Patch 1/1: Removed '-fopenmp' and '-lgomp' from CXX_FLAGS.")
    with open(setup_py_path, 'w', encoding='utf-8') as f:
        f.write(content.replace(original_cxx_flags, modified_cxx_flags))
    print("β Patches applied successfully.")
def install_sage_attention():
    """Clone, patch, and build SageAttention unless a checkout already exists.

    The presence of the ``REPO_DIR`` directory is treated as "already
    installed" and short-circuits the whole procedure. Otherwise the function
    clones ``REPO_URL`` into ``REPO_DIR``, patches its setup.py, and runs
    ``setup.py install`` with a fixed set of build environment variables.

    On any failure the process exits with status 1. If the failure happens
    after the clone started, the partial ``REPO_DIR`` is removed first, so that
    a later run retries instead of mistaking the leftover directory for a
    finished installation.
    """
    print("--- [SageAttention Build] Checking environment ---")
    if os.path.isdir(REPO_DIR):
        print(f"β Directory '{REPO_DIR}' already exists, assuming SageAttention is installed. Skipping build.")
        return

    print(f"β³ Directory '{REPO_DIR}' not found. Starting a fresh installation of SageAttention.")
    try:
        print("--- [SageAttention Build] Step 1/3: Cloning repository ---")
        try:
            # Name the target explicitly so the checkout always lands in
            # REPO_DIR rather than in whatever name git derives from the URL.
            run_command(["git", "clone", REPO_URL, REPO_DIR])
        except FileNotFoundError:
            # Only the clone step can fail because the git binary is missing;
            # a FileNotFoundError from later steps means something else.
            print("β ERROR: 'git' command not found. Please ensure Git is installed in your environment.")
            sys.exit(1)
        print("β Repository cloned successfully.")

        print("--- [SageAttention Build] Step 2/3: Patching setup.py ---")
        setup_py_path = os.path.join(REPO_DIR, "setup.py")
        patch_setup_py(setup_py_path)

        print("--- [SageAttention Build] Step 3/3: Compiling and installing ---")
        # Start from the current environment and override the build knobs.
        build_env = os.environ.copy()
        build_env.update({
            "TORCH_CUDA_ARCH_LIST": "9.0",
            "EXT_PARALLEL": "4",
            "NVCC_APPEND_FLAGS": "--threads 8",
            "MAX_JOBS": "32",
        })
        print("π§ Setting build environment variables:")
        print(f" - TORCH_CUDA_ARCH_LIST='{build_env['TORCH_CUDA_ARCH_LIST']}'")
        print(f" - EXT_PARALLEL={build_env['EXT_PARALLEL']}")
        print(f" - NVCC_APPEND_FLAGS='{build_env['NVCC_APPEND_FLAGS']}'")
        print(f" - MAX_JOBS={build_env['MAX_JOBS']}")

        # Build with the same interpreter that is running this script.
        install_command = [sys.executable, "setup.py", "install"]
        run_command(install_command, cwd=REPO_DIR, env=build_env)
        print("π SageAttention compiled and installed successfully! ---")
    except subprocess.CalledProcessError as e:
        print(f"β Command failed with return code: {e.returncode}")
        print(f"β Command: {' '.join(e.cmd)}")
        print("β SageAttention installation failed. Please check the logs above for details.")
    except Exception as e:
        print(f"β An unknown error occurred: {e}")
    else:
        return

    # Failure path. REPO_DIR did not exist when we started, so anything there
    # now is an incomplete checkout/build; remove it so the "directory exists"
    # check above does not report a broken install as done on the next run.
    shutil.rmtree(REPO_DIR, ignore_errors=True)
    sys.exit(1)
if __name__ == "__main__":
    # Script entry point: hint at how to force a rebuild when a checkout is
    # already present, then run the installer in either case.
    checkout_present = os.path.isdir(REPO_DIR)
    if checkout_present:
        print(f"Note: To force a rebuild, please delete the '{REPO_DIR}' directory first.")

    install_sage_attention()