"""Clone, patch, and build SageAttention from source.

Running this file as a script clones ``REPO_URL`` into ``REPO_DIR`` (relative
to the current working directory), rewrites the ``CXX_FLAGS`` line of its
``setup.py``, and runs ``setup.py install`` with a fixed set of build
environment variables. The existence of ``REPO_DIR`` is used as the
"already installed" marker.
"""
import os
import shutil
import subprocess
import sys
import textwrap

REPO_URL = "https://github.com/thu-ml/SageAttention.git"
REPO_DIR = "SageAttention"


def run_command(command, cwd=None, env=None):
    """Run *command* and raise if it exits with a non-zero status.

    stdout and stderr are merged and captured; the captured text is printed
    only when the command fails, and is also attached to the raised exception.

    Args:
        command: Argument list passed to ``subprocess.run`` (no shell).
        cwd: Optional working directory for the child process.
        env: Optional environment mapping for the child process.

    Raises:
        subprocess.CalledProcessError: The command exited non-zero. Its
            ``output`` attribute holds the captured stdout/stderr text.
        FileNotFoundError: The executable (or *cwd*) does not exist.
    """
    print(f"🚀 Running command: {' '.join(command)}")
    result = subprocess.run(
        command,
        cwd=cwd,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    if result.returncode != 0:
        print(result.stdout)
        # Attach the captured output so callers can inspect it, not just the
        # return code.
        raise subprocess.CalledProcessError(
            result.returncode, command, output=result.stdout
        )


def patch_setup_py(setup_py_path):
    """Remove ``-fopenmp`` and ``-lgomp`` from the ``CXX_FLAGS`` line.

    The patch is an exact-string replacement of the expected ``CXX_FLAGS``
    assignment. If that exact line is absent, the file is left untouched and
    a warning is printed.

    Args:
        setup_py_path: Path to the ``setup.py`` file to patch in place.

    Returns:
        True if the file was modified, False if the expected line was not
        found and nothing was written.

    Raises:
        OSError: The file cannot be read or written (e.g. it does not exist).
    """
    print(f"--- [SageAttention Build] Applying patches to {setup_py_path} ---")
    with open(setup_py_path, 'r', encoding='utf-8') as f:
        content = f.read()

    original_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]'
    modified_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-std=c++17", "-DENABLE_BF16"]'

    if original_cxx_flags not in content:
        print("⚠️ Patch 1/1: CXX_FLAGS line not found as expected. It might have been changed upstream. Skipping.")
        # Nothing changed: do not rewrite the file or report success.
        return False

    content = content.replace(original_cxx_flags, modified_cxx_flags)
    print("🔧 Patch 1/1: Removed '-fopenmp' and '-lgomp' from CXX_FLAGS.")

    with open(setup_py_path, 'w', encoding='utf-8') as f:
        f.write(content)
    print("✅ Patches applied successfully.")
    return True


def _remove_partial_checkout():
    """Delete ``REPO_DIR`` after a failed run so the next run retries.

    ``install_sage_attention`` treats an existing ``REPO_DIR`` as proof of a
    completed install, so a checkout left behind by a failed build would make
    every later run skip the build.
    """
    if not os.path.isdir(REPO_DIR):
        return
    print(f"🧹 Removing incomplete '{REPO_DIR}' directory so the next run retries the build.")
    shutil.rmtree(REPO_DIR, ignore_errors=True)
    if os.path.isdir(REPO_DIR):
        print(f"⚠️ Could not fully remove '{REPO_DIR}'. Please delete it manually before retrying.")


def install_sage_attention():
    """Clone, patch, and install SageAttention unless ``REPO_DIR`` exists.

    If ``REPO_DIR`` already exists the function returns immediately. Otherwise
    it clones the repository, patches ``setup.py``, and runs
    ``setup.py install`` inside the checkout with fixed build environment
    variables layered over the current environment.

    On any failure the process exits with status 1 (``SystemExit``), and the
    checkout created by this run is removed so that a later run does not
    mistake it for a finished install.
    """
    print("--- [SageAttention Build] Checking environment ---")
    if os.path.isdir(REPO_DIR):
        print(f"✅ Directory '{REPO_DIR}' already exists, assuming SageAttention is installed. Skipping build.")
        return

    print(f"⏳ Directory '{REPO_DIR}' not found. Starting a fresh installation of SageAttention.")

    succeeded = False
    try:
        print("--- [SageAttention Build] Step 1/3: Cloning repository ---")
        try:
            run_command(["git", "clone", REPO_URL])
        except FileNotFoundError:
            # Scoped to the clone step only: a missing setup.py later on must
            # not be reported as a missing git executable.
            print("❌ ERROR: 'git' command not found. Please ensure Git is installed in your environment.")
            sys.exit(1)
        print("✅ Repository cloned successfully.")

        print("--- [SageAttention Build] Step 2/3: Patching setup.py ---")
        setup_py_path = os.path.join(REPO_DIR, "setup.py")
        patch_setup_py(setup_py_path)

        print("--- [SageAttention Build] Step 3/3: Compiling and installing ---")
        build_env = os.environ.copy()
        # These values override any same-named variables already set in the
        # caller's environment.
        build_env.update({
            "TORCH_CUDA_ARCH_LIST": "9.0",
            "EXT_PARALLEL": "4",
            "NVCC_APPEND_FLAGS": "--threads 8",
            "MAX_JOBS": "32",
        })

        print("🔧 Setting build environment variables:")
        print(f" - TORCH_CUDA_ARCH_LIST='{build_env['TORCH_CUDA_ARCH_LIST']}'")
        print(f" - EXT_PARALLEL={build_env['EXT_PARALLEL']}")
        print(f" - NVCC_APPEND_FLAGS='{build_env['NVCC_APPEND_FLAGS']}'")
        print(f" - MAX_JOBS={build_env['MAX_JOBS']}")

        install_command = [sys.executable, "setup.py", "install"]
        run_command(install_command, cwd=REPO_DIR, env=build_env)
        print("🎉 SageAttention compiled and installed successfully! ---")
        succeeded = True
    except subprocess.CalledProcessError as e:
        print(f"❌ Command failed with return code: {e.returncode}")
        print(f"❌ Command: {' '.join(e.cmd)}")
        print("❌ SageAttention installation failed. Please check the logs above for details.")
        sys.exit(1)
    except Exception as e:
        print(f"❌ An unknown error occurred: {e}")
        sys.exit(1)
    finally:
        # Runs for SystemExit and KeyboardInterrupt too, so an aborted build
        # never leaves behind a directory that looks like a finished install.
        if not succeeded:
            _remove_partial_checkout()


if __name__ == "__main__":
    if os.path.isdir(REPO_DIR):
        print(f"Note: To force a rebuild, please delete the '{REPO_DIR}' directory first.")
    install_sage_attention()