File size: 3,833 Bytes
1875ee2
 
 
3ec6fdc
1875ee2
3ec6fdc
 
 
 
1875ee2
0f16a33
1875ee2
3ec6fdc
1875ee2
0f16a33
 
3ec6fdc
1875ee2
3ec6fdc
9de6815
3ec6fdc
9de6815
1875ee2
3ec6fdc
 
1875ee2
3ec6fdc
 
1875ee2
3ec6fdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1875ee2
3ec6fdc
 
9de6815
3ec6fdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1875ee2
3ec6fdc
 
 
 
1875ee2
3ec6fdc
 
1875ee2
 
3ec6fdc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import shutil
import subprocess
import sys
import textwrap

# Upstream repository that install_sage_attention() clones with `git clone`.
REPO_URL = "https://github.com/thu-ml/SageAttention.git"
# Directory the clone lands in (git names it after the repository). Its mere
# existence is what install_sage_attention() treats as "already installed".
REPO_DIR = "SageAttention"

def run_command(command, cwd=None, env=None):
    """Run *command* without a shell and raise if it exits non-zero.

    Stdout and stderr are merged and captured rather than streamed. The
    captured text is printed only when the command fails, so successful steps
    stay quiet.

    Args:
        command: Argument list passed straight to ``subprocess.run``.
        cwd: Working directory for the child, or ``None`` for the current one.
        env: Environment mapping for the child, or ``None`` to inherit ours.

    Returns:
        The ``subprocess.CompletedProcess``. Its ``stdout`` holds the merged
        output as text.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero. The
            captured text is attached as ``output`` so callers can inspect it.
        OSError: propagated from ``subprocess.run`` when the process cannot be
            started (e.g. ``FileNotFoundError`` for a missing executable).
    """
    # str() each item so path-like arguments don't break the log line.
    print(f"🚀 Running command: {' '.join(map(str, command))}")
    result = subprocess.run(
        command,
        cwd=cwd,
        env=env,
        stdout=subprocess.PIPE,
        # Merge stderr into stdout so the failure dump keeps the tool's ordering.
        stderr=subprocess.STDOUT,
        text=True,
        # Compiler output may contain bytes the locale codec can't decode;
        # replace them instead of crashing with UnicodeDecodeError.
        errors="replace",
    )

    if result.returncode != 0:
        print(result.stdout)
        raise subprocess.CalledProcessError(
            result.returncode, command, output=result.stdout
        )
    return result

def patch_setup_py(setup_py_path):
    """Remove ``-fopenmp`` and ``-lgomp`` from the ``CXX_FLAGS`` line of ``setup.py``.

    The function looks for the exact upstream ``CXX_FLAGS`` assignment and
    swaps it for a version without the two OpenMP flags. If that line is not
    present (for example because upstream changed it), the file is left
    untouched and a warning is printed instead of a success message.

    Args:
        setup_py_path: Path to the cloned repository's ``setup.py``.

    Returns:
        True if the file was patched, False if the expected line was missing.

    Raises:
        OSError: if the file cannot be read or written.
    """
    print(f"--- [SageAttention Build] Applying patches to {setup_py_path} ---")

    # newline="" keeps the file's existing line endings intact on read and write.
    with open(setup_py_path, "r", encoding="utf-8", newline="") as f:
        content = f.read()

    original_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]'
    modified_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-std=c++17", "-DENABLE_BF16"]'

    if original_cxx_flags not in content:
        print("⚠️ Patch 1/1: CXX_FLAGS line not found as expected. It might have been changed upstream. Skipping.")
        # Nothing changed, so don't rewrite the file or report success.
        return False

    content = content.replace(original_cxx_flags, modified_cxx_flags)
    print("🔧 Patch 1/1: Removed '-fopenmp' and '-lgomp' from CXX_FLAGS.")

    with open(setup_py_path, "w", encoding="utf-8", newline="") as f:
        f.write(content)

    print("✅ Patches applied successfully.")
    return True


def _discard_incomplete_checkout():
    """Delete a partially built REPO_DIR so the next run retries the build."""
    if not os.path.isdir(REPO_DIR):
        return
    print(f"🧹 Removing incomplete '{REPO_DIR}' so the next run rebuilds from scratch.")
    shutil.rmtree(REPO_DIR, ignore_errors=True)
    if os.path.isdir(REPO_DIR):
        print(f"⚠️ Could not fully remove '{REPO_DIR}'. Delete it manually before retrying.")


def install_sage_attention():
    """Clone, patch and build SageAttention unless a checkout already exists.

    The steps are:
      1. ``git clone REPO_URL``, which creates ``REPO_DIR``.
      2. Patch ``REPO_DIR/setup.py`` with ``patch_setup_py``.
      3. Run ``setup.py install`` inside ``REPO_DIR`` with build-tuning
         environment variables.

    The existence of ``REPO_DIR`` is treated as "already installed". For that
    reason, if any step fails (or the run is interrupted), the checkout this
    call created is removed again. Otherwise the next run would skip the build
    and wrongly report success.

    On a failed step, an error is printed and the interpreter exits with
    ``sys.exit(1)``.
    """
    print("--- [SageAttention Build] Checking environment ---")

    if os.path.isdir(REPO_DIR):
        print(f"✅ Directory '{REPO_DIR}' already exists, assuming SageAttention is installed. Skipping build.")
        return

    # Check for git up front, so that a FileNotFoundError raised later
    # (e.g. a missing setup.py) is not misreported as a missing git binary.
    if shutil.which("git") is None:
        print("❌ ERROR: 'git' command not found. Please ensure Git is installed in your environment.")
        sys.exit(1)

    print(f"⏳ Directory '{REPO_DIR}' not found. Starting a fresh installation of SageAttention.")

    succeeded = False
    try:
        print("--- [SageAttention Build] Step 1/3: Cloning repository ---")
        run_command(["git", "clone", REPO_URL])
        print("✅ Repository cloned successfully.")

        print("--- [SageAttention Build] Step 2/3: Patching setup.py ---")
        setup_py_path = os.path.join(REPO_DIR, "setup.py")
        patch_setup_py(setup_py_path)

        print("--- [SageAttention Build] Step 3/3: Compiling and installing ---")

        # NOTE(review): these values override anything already set in the
        # caller's environment. The arch list pins the build to a single
        # CUDA arch ("9.0").
        build_env = os.environ.copy()
        build_env.update({
            "TORCH_CUDA_ARCH_LIST": "9.0",
            "EXT_PARALLEL": "4",
            "NVCC_APPEND_FLAGS": "--threads 8",
            "MAX_JOBS": "32",
        })
        print("🔧 Setting build environment variables:")
        print(f"   - TORCH_CUDA_ARCH_LIST='{build_env['TORCH_CUDA_ARCH_LIST']}'")
        print(f"   - EXT_PARALLEL={build_env['EXT_PARALLEL']}")
        print(f"   - NVCC_APPEND_FLAGS='{build_env['NVCC_APPEND_FLAGS']}'")
        print(f"   - MAX_JOBS={build_env['MAX_JOBS']}")

        # NOTE(review): running `setup.py install` directly is deprecated by
        # setuptools; consider `pip install --no-build-isolation .` instead.
        install_command = [sys.executable, "setup.py", "install"]
        run_command(install_command, cwd=REPO_DIR, env=build_env)

        succeeded = True
        print("🎉 SageAttention compiled and installed successfully! ---")

    except FileNotFoundError as e:
        print(f"❌ ERROR: Required file or program not found: {e.filename or e}")
        sys.exit(1)
    except subprocess.CalledProcessError as e:
        print(f"❌ Command failed with return code: {e.returncode}")
        print(f"❌ Command: {' '.join(map(str, e.cmd))}")
        print("❌ SageAttention installation failed. Please check the logs above for details.")
        sys.exit(1)
    except Exception as e:
        print(f"❌ An unknown error occurred: {e}")
        sys.exit(1)
    finally:
        # Runs before SystemExit/KeyboardInterrupt propagate. It drops the
        # checkout this call created, so REPO_DIR never marks a failed build
        # as installed.
        if not succeeded:
            _discard_incomplete_checkout()

if __name__ == "__main__":
    # When a checkout is already present the installer will skip the build,
    # so first tell the user how to force a fresh one.
    checkout_present = os.path.isdir(REPO_DIR)
    if checkout_present:
        print(f"Note: To force a rebuild, please delete the '{REPO_DIR}' directory first.")
    install_sage_attention()