Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,833 Bytes
9242df6 a6882ee 9242df6 a6882ee 9242df6 5c32955 9242df6 a6882ee 9242df6 5c32955 a6882ee 9242df6 a6882ee 74500b8 a6882ee 74500b8 9242df6 a6882ee 9242df6 a6882ee 9242df6 a6882ee 9242df6 a6882ee 74500b8 a6882ee 9242df6 a6882ee 9242df6 a6882ee 9242df6 a6882ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import os
import shutil
import subprocess
import sys
import textwrap
# Upstream SageAttention git repository to clone.
REPO_URL: str = "https://github.com/thu-ml/SageAttention.git"
# Local checkout directory; install_sage_attention() treats its existence
# as "SageAttention is already installed".
REPO_DIR: str = "SageAttention"


def run_command(command, cwd=None, env=None):
    """Run *command* and raise if it exits with a non-zero status.

    stdout and stderr are merged into one captured text stream. That output
    is printed only when the command fails, so successful runs stay quiet.

    Args:
        command: Argument list passed to ``subprocess.run`` (no shell).
        cwd: Working directory for the child process, or ``None`` for the
            current directory.
        env: Environment mapping for the child process, or ``None`` to
            inherit the current environment.

    Returns:
        The ``subprocess.CompletedProcess``; its ``stdout`` holds the merged
        stdout/stderr text.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero. The
            merged output is attached as ``output`` so callers can inspect it.
        FileNotFoundError: If the executable (or *cwd*) does not exist.
    """
    # Flush so the status line appears before a possibly long-running child.
    print(f"🚀 Running command: {' '.join(command)}", flush=True)
    result = subprocess.run(
        command,
        cwd=cwd,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave stderr into the same log
        text=True,
    )
    if result.returncode != 0:
        print(result.stdout)
        # Attach the captured log instead of discarding it.
        raise subprocess.CalledProcessError(
            result.returncode, command, output=result.stdout
        )
    return result


def patch_setup_py(setup_py_path):
    """Remove ``-fopenmp`` and ``-lgomp`` from SageAttention's ``CXX_FLAGS``.

    Looks for the exact upstream ``CXX_FLAGS`` line and replaces it with a
    copy that omits the two OpenMP flags. If the line is not found verbatim
    (e.g. upstream changed it), a warning is printed and the file is left
    untouched. Line endings are preserved.

    Args:
        setup_py_path: Path to the ``setup.py`` file to patch in place.

    Returns:
        ``True`` if the patch was applied, ``False`` if it was skipped.

    Raises:
        OSError: If the file cannot be read or written.
    """
    print(f"--- [SageAttention Build] Applying patches to {setup_py_path} ---")
    # newline='' keeps the file's original line endings on read and write.
    with open(setup_py_path, 'r', encoding='utf-8', newline='') as f:
        content = f.read()

    original_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]'
    modified_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-std=c++17", "-DENABLE_BF16"]'

    if original_cxx_flags not in content:
        print("⚠️ Patch 1/1: CXX_FLAGS line not found as expected. It might have been changed upstream. Skipping.")
        # Nothing changed: don't rewrite the file or claim success.
        return False

    content = content.replace(original_cxx_flags, modified_cxx_flags)
    print("🔧 Patch 1/1: Removed '-fopenmp' and '-lgomp' from CXX_FLAGS.")
    with open(setup_py_path, 'w', encoding='utf-8', newline='') as f:
        f.write(content)
    print("✅ Patches applied successfully.")
    return True


def _make_build_env():
    """Return a copy of ``os.environ`` with the SageAttention build settings applied.

    The chosen values are printed so they appear in the build log.
    """
    build_env = os.environ.copy()
    build_env.update({
        # Compile kernels only for CUDA compute capability 9.0.
        # NOTE(review): presumably matches the deployment GPU — confirm
        # before building for other hardware.
        "TORCH_CUDA_ARCH_LIST": "9.0",
        # Presumably read by SageAttention's setup.py to build its extensions
        # in parallel — TODO confirm against upstream.
        "EXT_PARALLEL": "4",
        # Extra flags appended to every nvcc invocation.
        "NVCC_APPEND_FLAGS": "--threads 8",
        # Upper bound on parallel compile jobs for PyTorch's extension builder.
        "MAX_JOBS": "32",
    })
    print("🔧 Setting build environment variables:")
    print(f" - TORCH_CUDA_ARCH_LIST='{build_env['TORCH_CUDA_ARCH_LIST']}'")
    print(f" - EXT_PARALLEL={build_env['EXT_PARALLEL']}")
    print(f" - NVCC_APPEND_FLAGS='{build_env['NVCC_APPEND_FLAGS']}'")
    print(f" - MAX_JOBS={build_env['MAX_JOBS']}")
    return build_env


def _remove_partial_checkout():
    """Delete an incomplete ``REPO_DIR`` so the next run retries the build."""
    if os.path.isdir(REPO_DIR):
        print(f"🧹 Removing incomplete '{REPO_DIR}' checkout so the next run retries the build.")
        shutil.rmtree(REPO_DIR, ignore_errors=True)


def install_sage_attention():
    """Clone, patch, compile, and install SageAttention unless already present.

    The existence of the ``REPO_DIR`` directory is treated as "already
    installed"; in that case this returns immediately. Otherwise it:

    1. clones ``REPO_URL`` into ``REPO_DIR``,
    2. strips the OpenMP flags from ``REPO_DIR/setup.py``,
    3. runs ``setup.py install`` with the current interpreter and the build
       environment from ``_make_build_env``.

    On any failure an error is printed, and the partially built ``REPO_DIR``
    is removed. Without that cleanup, the next run would wrongly assume
    success. The process then exits with status 1 via ``sys.exit``.
    """
    print("--- [SageAttention Build] Checking environment ---")
    if os.path.isdir(REPO_DIR):
        print(f"✅ Directory '{REPO_DIR}' already exists, assuming SageAttention is installed. Skipping build.")
        return
    print(f"⏳ Directory '{REPO_DIR}' not found. Starting a fresh installation of SageAttention.")

    # Check for git explicitly. A FileNotFoundError later on may come from a
    # missing setup.py or cwd, not only from a missing git binary.
    if shutil.which("git") is None:
        print("❌ ERROR: 'git' command not found. Please ensure Git is installed in your environment.")
        sys.exit(1)

    succeeded = False
    try:
        print("--- [SageAttention Build] Step 1/3: Cloning repository ---")
        # Clone into REPO_DIR explicitly instead of relying on git deriving
        # the same directory name from the URL.
        run_command(["git", "clone", REPO_URL, REPO_DIR])
        print("✅ Repository cloned successfully.")

        print("--- [SageAttention Build] Step 2/3: Patching setup.py ---")
        patch_setup_py(os.path.join(REPO_DIR, "setup.py"))

        print("--- [SageAttention Build] Step 3/3: Compiling and installing ---")
        build_env = _make_build_env()
        run_command([sys.executable, "setup.py", "install"], cwd=REPO_DIR, env=build_env)
        print("🎉 SageAttention compiled and installed successfully! ---")
        succeeded = True
    except FileNotFoundError as e:
        print(f"❌ ERROR: Required file or program not found: {e}")
        sys.exit(1)
    except subprocess.CalledProcessError as e:
        cmd = e.cmd if isinstance(e.cmd, str) else ' '.join(e.cmd)
        print(f"❌ Command failed with return code: {e.returncode}")
        print(f"❌ Command: {cmd}")
        print("❌ SageAttention installation failed. Please check the logs above for details.")
        sys.exit(1)
    except Exception as e:
        print(f"❌ An unknown error occurred: {e}")
        sys.exit(1)
    finally:
        # Also runs on SystemExit/KeyboardInterrupt. Without it, the leftover
        # directory would make every later run skip the build.
        if not succeeded:
            _remove_partial_checkout()


if __name__ == "__main__":
    # An existing checkout makes install_sage_attention() skip the build,
    # so tell the user how to force a rebuild.
    if os.path.isdir(REPO_DIR):
        print(f"Note: To force a rebuild, please delete the '{REPO_DIR}' directory first.")
    install_sage_attention()