File size: 3,833 Bytes
5b29993
 
 
de5d312
5b29993
de5d312
 
 
 
5b29993
 
 
de5d312
5b29993
 
 
de5d312
5b29993
de5d312
f465e8b
de5d312
f465e8b
5b29993
de5d312
 
5b29993
de5d312
 
5b29993
de5d312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b29993
de5d312
 
f465e8b
de5d312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b29993
de5d312
 
 
 
5b29993
de5d312
 
5b29993
 
de5d312
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import shutil
import subprocess
import sys
import textwrap

# Upstream SageAttention git repository that the installer clones.
REPO_URL = "https://github.com/thu-ml/SageAttention.git"
# Local checkout directory, resolved relative to the current working directory.
# install_sage_attention() treats its existence as "already installed" and
# skips the build; it also runs the build from inside this directory.
REPO_DIR = "SageAttention"

def run_command(command, cwd=None, env=None):
    """Run *command* with stdout and stderr merged, failing loudly on error.

    Output is captured rather than streamed. It is echoed only when the
    command fails, so successful runs stay quiet.

    Args:
        command: Argument sequence passed to ``subprocess.run`` (no shell).
        cwd: Working directory for the child process, or None for the current one.
        env: Environment mapping for the child, or None to inherit ours.

    Returns:
        The ``subprocess.CompletedProcess`` of a successful run. Its ``stdout``
        holds the combined stdout/stderr text.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero. The
            combined output is attached as ``output`` (alias ``stdout``).
        OSError: For example ``FileNotFoundError``, if the process cannot be
            started at all.
    """
    # str() each part so path-like arguments don't break the log line.
    printable = " ".join(str(arg) for arg in command)
    print(f"🚀 Running command: {printable}")
    result = subprocess.run(
        command,
        cwd=cwd,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        # Compiler/toolchain output is not guaranteed to be valid in the locale
        # encoding; a stray byte must not mask the real failure with a
        # UnicodeDecodeError.
        errors="replace",
    )

    if result.returncode != 0:
        print(result.stdout)
        # Attach the captured output so callers can inspect it, not just read it
        # from the console.
        raise subprocess.CalledProcessError(
            result.returncode, command, output=result.stdout
        )
    return result

def patch_setup_py(setup_py_path):
    """Remove ``-fopenmp`` and ``-lgomp`` from SageAttention's ``CXX_FLAGS`` in place.

    The upstream ``CXX_FLAGS`` line is matched verbatim. If it has changed
    upstream, the file is left untouched and a warning is printed.

    Args:
        setup_py_path: Path to the ``setup.py`` file to patch.

    Returns:
        True if the file was rewritten, False if the expected line was not found.
    """
    print(f"--- [SageAttention Build] Applying patches to {setup_py_path} ---")

    # newline="" disables newline translation so the file keeps its original
    # line endings (e.g. CRLF) across the read/write round trip.
    with open(setup_py_path, "r", encoding="utf-8", newline="") as f:
        content = f.read()

    original_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]'
    modified_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-std=c++17", "-DENABLE_BF16"]'

    if original_cxx_flags not in content:
        print("⚠️ Patch 1/1: CXX_FLAGS line not found as expected. It might have been changed upstream. Skipping.")
        # Nothing changed: don't rewrite the file or claim success.
        print("⚠️ No patches were applied; setup.py left unchanged.")
        return False

    content = content.replace(original_cxx_flags, modified_cxx_flags)
    print("🔧 Patch 1/1: Removed '-fopenmp' and '-lgomp' from CXX_FLAGS.")

    with open(setup_py_path, "w", encoding="utf-8", newline="") as f:
        f.write(content)

    print("✅ Patches applied successfully.")
    return True


def _remove_partial_clone():
    """Delete a REPO_DIR left behind by an unfinished install so a rerun rebuilds it."""
    # install_sage_attention() treats an existing REPO_DIR as "already installed",
    # so a checkout whose build did not complete must not be left in place.
    if os.path.isdir(REPO_DIR):
        shutil.rmtree(REPO_DIR, ignore_errors=True)
        print(f"🧹 Removed incomplete '{REPO_DIR}' directory so the next run starts fresh.")


def install_sage_attention():
    """Clone, patch, and compile/install SageAttention unless REPO_DIR already exists.

    An existing REPO_DIR is taken to mean SageAttention is already installed,
    and the function returns without doing anything. To keep that assumption
    honest, REPO_DIR is deleted again if this run fails or is interrupted
    after starting the clone.

    On failure, the interpreter exits with status 1 (``sys.exit``) instead of
    an exception being propagated.
    """
    print("--- [SageAttention Build] Checking environment ---")

    if os.path.isdir(REPO_DIR):
        print(f"✅ Directory '{REPO_DIR}' already exists, assuming SageAttention is installed. Skipping build.")
        return

    # Check for git up front, so a FileNotFoundError raised later can't be
    # misreported as a missing git binary.
    if shutil.which("git") is None:
        print("❌ ERROR: 'git' command not found. Please ensure Git is installed in your environment.")
        sys.exit(1)

    print(f"⏳ Directory '{REPO_DIR}' not found. Starting a fresh installation of SageAttention.")

    completed = False
    try:
        print("--- [SageAttention Build] Step 1/3: Cloning repository ---")
        # Clone into REPO_DIR explicitly so the existence check above and the
        # paths below always refer to the same directory.
        run_command(["git", "clone", REPO_URL, REPO_DIR])
        print("✅ Repository cloned successfully.")

        print("--- [SageAttention Build] Step 2/3: Patching setup.py ---")
        setup_py_path = os.path.join(REPO_DIR, "setup.py")
        patch_setup_py(setup_py_path)

        print("--- [SageAttention Build] Step 3/3: Compiling and installing ---")

        build_env = os.environ.copy()
        # These values override anything already set in the caller's environment.
        build_env.update({
            "TORCH_CUDA_ARCH_LIST": "9.0",
            "EXT_PARALLEL": "4",
            "NVCC_APPEND_FLAGS": "--threads 8",
            "MAX_JOBS": "32",
        })
        print("🔧 Setting build environment variables:")
        print(f"   - TORCH_CUDA_ARCH_LIST='{build_env['TORCH_CUDA_ARCH_LIST']}'")
        print(f"   - EXT_PARALLEL={build_env['EXT_PARALLEL']}")
        print(f"   - NVCC_APPEND_FLAGS='{build_env['NVCC_APPEND_FLAGS']}'")
        print(f"   - MAX_JOBS={build_env['MAX_JOBS']}")

        # NOTE(review): `setup.py install` is deprecated by setuptools; consider
        # `pip install .` (likely with --no-build-isolation) once verified.
        install_command = [sys.executable, "setup.py", "install"]
        run_command(install_command, cwd=REPO_DIR, env=build_env)

        completed = True
        print("🎉 SageAttention compiled and installed successfully! ---")

    except FileNotFoundError as e:
        print(f"❌ ERROR: Required file or command not found: {e.filename or e}")
        sys.exit(1)
    except subprocess.CalledProcessError as e:
        print(f"❌ Command failed with return code: {e.returncode}")
        print(f"❌ Command: {' '.join(str(arg) for arg in e.cmd)}")
        print("❌ SageAttention installation failed. Please check the logs above for details.")
        sys.exit(1)
    except Exception as e:
        print(f"❌ An unknown error occurred: {type(e).__name__}: {e}")
        sys.exit(1)
    finally:
        # Also runs while the SystemExit raised above propagates, and on
        # KeyboardInterrupt, so no half-built checkout survives to be mistaken
        # for a finished install.
        if not completed:
            _remove_partial_clone()

if __name__ == "__main__":
    # Tell the user how to force a rebuild before install_sage_attention()
    # decides whether the existing checkout lets it skip the build.
    repo_already_cloned = os.path.isdir(REPO_DIR)
    if repo_already_cloned:
        print(f"Note: To force a rebuild, please delete the '{REPO_DIR}' directory first.")
    install_sage_attention()