Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,833 Bytes
f36e497 be200d7 f36e497 be200d7 f36e497 be200d7 f36e497 be200d7 f36e497 be200d7 f36e497 be200d7 f36e497 be200d7 f36e497 be200d7 f36e497 be200d7 f36e497 be200d7 f36e497 be200d7 f36e497 be200d7 f36e497 be200d7 f36e497 be200d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import os
import subprocess
import sys
import textwrap
# Upstream Git repository that is cloned to obtain the SageAttention sources.
REPO_URL = "https://github.com/thu-ml/SageAttention.git"
# Local checkout directory. Its presence is used as the "already installed"
# marker, and the build is run from inside it.
REPO_DIR = "SageAttention"
def run_command(command, cwd=None, env=None):
    """Run an external command, showing its output only if it fails.

    stdout and stderr of the child process are merged and captured. On
    success the captured output is discarded; on failure it is printed and
    attached to the raised exception so callers can inspect it.

    Args:
        command: Argument list handed to ``subprocess.run`` (no shell).
        cwd: Working directory for the child, or ``None`` to inherit.
        env: Environment mapping for the child, or ``None`` to inherit.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status. The merged output is available as ``.output``.
    """
    # map(str, ...) keeps the log line working for non-string arguments
    # such as pathlib.Path objects.
    print(f"🚀 Running command: {' '.join(map(str, command))}")
    result = subprocess.run(
        command,
        cwd=cwd,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    if result.returncode != 0:
        print(result.stdout)
        raise subprocess.CalledProcessError(
            result.returncode, command, output=result.stdout
        )
def patch_setup_py(setup_py_path):
    """Remove the OpenMP flags from ``CXX_FLAGS`` in a SageAttention setup.py.

    The file is searched for the exact upstream ``CXX_FLAGS`` line; when
    present it is replaced by the same line without ``-fopenmp`` and
    ``-lgomp``. The file is rewritten only if that replacement happened.

    Args:
        setup_py_path: Path of the ``setup.py`` file to patch in place.

    Returns:
        ``True`` if the line was found and the file rewritten, ``False`` if
        the expected line was absent and the file was left untouched.
    """
    print(f"--- [SageAttention Build] Applying patches to {setup_py_path} ---")
    with open(setup_py_path, 'r', encoding='utf-8') as f:
        content = f.read()

    original_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]'
    modified_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-std=c++17", "-DENABLE_BF16"]'

    if original_cxx_flags not in content:
        # Nothing to change: do not rewrite the file or claim success.
        print("⚠️ Patch 1/1: CXX_FLAGS line not found as expected. It might have been changed upstream. Skipping.")
        return False

    content = content.replace(original_cxx_flags, modified_cxx_flags)
    print("🔧 Patch 1/1: Removed '-fopenmp' and '-lgomp' from CXX_FLAGS.")
    with open(setup_py_path, 'w', encoding='utf-8') as f:
        f.write(content)
    print("✅ Patches applied successfully.")
    return True
def install_sage_attention():
    """Clone, patch, and build SageAttention unless a checkout already exists.

    If ``REPO_DIR`` already exists the build is skipped on the assumption
    that a previous run installed the package. Otherwise the repository is
    cloned into ``REPO_DIR``, its ``setup.py`` is patched, and
    ``setup.py install`` is run with the current interpreter and a fixed set
    of build environment variables.

    On any failure the error is reported, the checkout created by this run
    is removed (so that the next run retries instead of mistaking the
    leftover directory for a finished install), and the process exits with
    status 1.
    """
    import shutil  # local import: only needed for failure cleanup

    print("--- [SageAttention Build] Checking environment ---")
    if os.path.isdir(REPO_DIR):
        print(f"✅ Directory '{REPO_DIR}' already exists, assuming SageAttention is installed. Skipping build.")
        return

    print(f"⏳ Directory '{REPO_DIR}' not found. Starting a fresh installation of SageAttention.")
    try:
        print("--- [SageAttention Build] Step 1/3: Cloning repository ---")
        # Pass the target directory explicitly instead of relying on git's
        # default name happening to equal REPO_DIR.
        run_command(["git", "clone", REPO_URL, REPO_DIR])
        print("✅ Repository cloned successfully.")

        print("--- [SageAttention Build] Step 2/3: Patching setup.py ---")
        setup_py_path = os.path.join(REPO_DIR, "setup.py")
        patch_setup_py(setup_py_path)

        print("--- [SageAttention Build] Step 3/3: Compiling and installing ---")
        build_env = os.environ.copy()
        build_env.update({
            "TORCH_CUDA_ARCH_LIST": "9.0",
            "EXT_PARALLEL": "4",
            "NVCC_APPEND_FLAGS": "--threads 8",
            "MAX_JOBS": "32",
        })
        print("🔧 Setting build environment variables:")
        print(f" - TORCH_CUDA_ARCH_LIST='{build_env['TORCH_CUDA_ARCH_LIST']}'")
        print(f" - EXT_PARALLEL={build_env['EXT_PARALLEL']}")
        print(f" - NVCC_APPEND_FLAGS='{build_env['NVCC_APPEND_FLAGS']}'")
        print(f" - MAX_JOBS={build_env['MAX_JOBS']}")

        install_command = [sys.executable, "setup.py", "install"]
        run_command(install_command, cwd=REPO_DIR, env=build_env)
        print("🎉 SageAttention compiled and installed successfully! ---")
    except FileNotFoundError as e:
        # A missing executable and a missing setup.py both end up here, so
        # only blame git when git is actually what could not be found.
        if e.filename == "git":
            print("❌ ERROR: 'git' command not found. Please ensure Git is installed in your environment.")
        else:
            print(f"❌ ERROR: Required file not found: {e.filename or e}")
    except subprocess.CalledProcessError as e:
        print(f"❌ Command failed with return code: {e.returncode}")
        print(f"❌ Command: {' '.join(e.cmd)}")
        print("❌ SageAttention installation failed. Please check the logs above for details.")
    except Exception as e:
        print(f"❌ An unknown error occurred: {e}")
    else:
        return

    # Failure path: REPO_DIR did not exist when this run started, so anything
    # there now is an incomplete checkout from this run.
    if os.path.isdir(REPO_DIR):
        shutil.rmtree(REPO_DIR, ignore_errors=True)
        print(f"🧹 Removed incomplete '{REPO_DIR}' checkout so the next run can retry.")
    sys.exit(1)
if __name__ == "__main__":
    # When run as a script, tell the user how to force a rebuild before the
    # installer reports that it is skipping an existing checkout.
    if os.path.isdir(REPO_DIR):
        print(f"Note: To force a rebuild, please delete the '{REPO_DIR}' directory first.")
    install_sage_attention()