File size: 3,833 Bytes
888ef00
 
 
06ac9b8
888ef00
06ac9b8
 
 
 
888ef00
8dc5e91
888ef00
06ac9b8
888ef00
8dc5e91
 
06ac9b8
888ef00
06ac9b8
92daf2e
06ac9b8
92daf2e
888ef00
06ac9b8
 
888ef00
06ac9b8
 
888ef00
06ac9b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
888ef00
06ac9b8
 
92daf2e
06ac9b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
888ef00
06ac9b8
 
 
 
888ef00
06ac9b8
 
888ef00
 
06ac9b8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import subprocess
import sys
import textwrap

# Git remote that the SageAttention sources are cloned from.
REPO_URL = "https://github.com/thu-ml/SageAttention.git"
# Checkout directory, resolved relative to the current working directory.
# Its mere presence is treated as "already installed" by the installer below.
REPO_DIR = "SageAttention"

def run_command(command, cwd=None, env=None):
    """Run ``command`` as a subprocess and fail loudly on a non-zero exit.

    stdout and stderr of the child are merged and captured as text. The
    captured output is printed only when the command fails.

    Args:
        command: Argument list handed to ``subprocess.run`` (no shell).
        cwd: Optional working directory for the child process.
        env: Optional environment mapping for the child process.

    Returns:
        The ``subprocess.CompletedProcess`` of the successful run; its
        ``stdout`` attribute holds the combined output.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status. The captured output is attached as ``output``/``stdout``
            so callers can log or inspect it.
        FileNotFoundError: Propagated from ``subprocess.run`` when the
            executable cannot be found.
    """
    print(f"πŸš€ Running command: {' '.join(command)}")
    result = subprocess.run(
        command,
        cwd=cwd,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave stderr with stdout in one log
        text=True,
    )

    if result.returncode != 0:
        print(result.stdout)
        # Attach the log to the exception instead of discarding it.
        raise subprocess.CalledProcessError(
            result.returncode, command, output=result.stdout
        )

    return result

def patch_setup_py(setup_py_path):
    """Strip the OpenMP flags from the CXX_FLAGS line of a ``setup.py``.

    The file is searched for one exact CXX_FLAGS definition; when found it is
    replaced by the same list without ``-fopenmp`` and ``-lgomp`` and the file
    is written back. When the line is not present, the file is left untouched
    and a warning is printed instead of a success message.

    Args:
        setup_py_path: Path of the ``setup.py`` file to patch (UTF-8 text).

    Returns:
        ``True`` if the patch was applied and the file rewritten, ``False``
        if the expected line was not found and nothing was changed.
    """
    print(f"--- [SageAttention Build] Applying patches to {setup_py_path} ---")

    with open(setup_py_path, 'r', encoding='utf-8') as f:
        content = f.read()

    original_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]'
    modified_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-std=c++17", "-DENABLE_BF16"]'

    if original_cxx_flags not in content:
        # Nothing to patch: do not rewrite the file or claim success.
        print("⚠️ Patch 1/1: CXX_FLAGS line not found as expected. It might have been changed upstream. Skipping.")
        return False

    content = content.replace(original_cxx_flags, modified_cxx_flags)
    print("πŸ”§ Patch 1/1: Removed '-fopenmp' and '-lgomp' from CXX_FLAGS.")

    with open(setup_py_path, 'w', encoding='utf-8') as f:
        f.write(content)

    print("βœ… Patches applied successfully.")
    return True


def _remove_failed_checkout():
    """Delete an incomplete checkout so the next run starts from scratch."""
    import shutil  # local import: only needed on the failure path

    if os.path.isdir(REPO_DIR):
        print(f"Removing incomplete '{REPO_DIR}' directory so the next run can retry.")
        shutil.rmtree(REPO_DIR, ignore_errors=True)


def install_sage_attention():
    """Clone, patch, and build SageAttention unless a checkout already exists.

    If ``REPO_DIR`` is already a directory the build is skipped. Otherwise the
    repository is cloned into ``REPO_DIR``, its ``setup.py`` is patched, and
    ``setup.py install`` is run with a fixed set of build environment
    variables layered over the current environment.

    On any failure (including an interrupted build) the checkout created by
    this call is removed again; otherwise the leftover directory would make
    the next run believe the installation had succeeded.

    Raises:
        SystemExit: With exit code 1 when git is missing, a command fails, or
            any other error occurs during the installation.
    """
    print("--- [SageAttention Build] Checking environment ---")

    if os.path.isdir(REPO_DIR):
        print(f"βœ… Directory '{REPO_DIR}' already exists, assuming SageAttention is installed. Skipping build.")
        return

    print(f"⏳ Directory '{REPO_DIR}' not found. Starting a fresh installation of SageAttention.")

    build_succeeded = False
    try:
        print("--- [SageAttention Build] Step 1/3: Cloning repository ---")
        try:
            # Name the target explicitly so the checkout always lands in
            # REPO_DIR instead of relying on git's default directory name.
            run_command(["git", "clone", REPO_URL, REPO_DIR])
        except FileNotFoundError:
            # Only the clone step can fail because the git executable is
            # missing; later FileNotFoundErrors have other causes.
            print("❌ ERROR: 'git' command not found. Please ensure Git is installed in your environment.")
            sys.exit(1)
        print("βœ… Repository cloned successfully.")

        print("--- [SageAttention Build] Step 2/3: Patching setup.py ---")
        setup_py_path = os.path.join(REPO_DIR, "setup.py")
        patch_setup_py(setup_py_path)

        print("--- [SageAttention Build] Step 3/3: Compiling and installing ---")

        build_env = os.environ.copy()
        build_env.update({
            "TORCH_CUDA_ARCH_LIST": "9.0",
            "EXT_PARALLEL": "4",
            "NVCC_APPEND_FLAGS": "--threads 8",
            "MAX_JOBS": "32",
        })
        print("πŸ”§ Setting build environment variables:")
        print(f"   - TORCH_CUDA_ARCH_LIST='{build_env['TORCH_CUDA_ARCH_LIST']}'")
        print(f"   - EXT_PARALLEL={build_env['EXT_PARALLEL']}")
        print(f"   - NVCC_APPEND_FLAGS='{build_env['NVCC_APPEND_FLAGS']}'")
        print(f"   - MAX_JOBS={build_env['MAX_JOBS']}")

        # Build with the interpreter running this script so the extension is
        # installed into the same environment.
        install_command = [sys.executable, "setup.py", "install"]
        run_command(install_command, cwd=REPO_DIR, env=build_env)

        print("πŸŽ‰ SageAttention compiled and installed successfully! ---")
        build_succeeded = True

    except subprocess.CalledProcessError as e:
        print(f"❌ Command failed with return code: {e.returncode}")
        print(f"❌ Command: {' '.join(e.cmd)}")
        print("❌ SageAttention installation failed. Please check the logs above for details.")
        sys.exit(1)
    except Exception as e:
        print(f"❌ An unknown error occurred: {e}")
        sys.exit(1)
    finally:
        # Runs for SystemExit and KeyboardInterrupt as well, so a half-built
        # checkout never masquerades as a finished installation.
        if not build_succeeded:
            _remove_failed_checkout()

if __name__ == "__main__":
    # Script entry point: hint at how to force a rebuild when a checkout is
    # already present, then hand over to the installer.
    checkout_present = os.path.isdir(REPO_DIR)
    if checkout_present:
        print(f"Note: To force a rebuild, please delete the '{REPO_DIR}' directory first.")
    install_sage_attention()