File size: 3,833 Bytes
9242df6
 
 
a6882ee
9242df6
a6882ee
 
 
 
9242df6
5c32955
9242df6
a6882ee
9242df6
5c32955
 
a6882ee
9242df6
a6882ee
74500b8
a6882ee
74500b8
9242df6
a6882ee
 
9242df6
a6882ee
 
9242df6
a6882ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9242df6
a6882ee
 
74500b8
a6882ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9242df6
a6882ee
 
 
 
9242df6
a6882ee
 
9242df6
 
a6882ee
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import subprocess
import sys
import textwrap

# Upstream SageAttention repository.
REPO_URL = "https://github.com/thu-ml/SageAttention.git"
# Checkout directory: the last URL path component minus ".git", which is the
# name a plain `git clone REPO_URL` gives the clone ("SageAttention").
REPO_DIR = os.path.splitext(REPO_URL.rsplit("/", 1)[-1])[0]

def run_command(command, cwd=None, env=None):
    """Run *command* (no shell) and raise if it exits with a non-zero status.

    stdout and stderr are merged into one captured stream. On success the
    captured text is returned without being echoed. On failure it is printed
    so the user sees the tool's diagnostics, and it is also attached to the
    raised exception.

    Args:
        command: Argument list; ``command[0]`` is the executable.
        cwd: Working directory for the child process, or None for the current one.
        env: Environment mapping for the child process, or None to inherit ours.

    Returns:
        The combined stdout/stderr text of the command.

    Raises:
        FileNotFoundError: Raised by ``subprocess.run`` when the executable
            (or ``cwd``) does not exist.
        subprocess.CalledProcessError: If the command exits non-zero; its
            ``output`` attribute holds the combined stdout/stderr text.
    """
    print(f"🚀 Running command: {' '.join(map(str, command))}")
    result = subprocess.run(
        command,
        cwd=cwd,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        # Compiler/build output is not guaranteed to decode cleanly in the
        # locale's encoding; a stray byte must not turn a log into a crash.
        errors="replace",
    )

    if result.returncode != 0:
        print(result.stdout)
        # Attach the captured log so callers can inspect it programmatically.
        raise subprocess.CalledProcessError(
            result.returncode, command, output=result.stdout
        )
    return result.stdout

def patch_setup_py(setup_py_path):
    """Remove the OpenMP flags ``-fopenmp`` and ``-lgomp`` from ``CXX_FLAGS`` in setup.py.

    Matching targets the quoted flags inside the ``CXX_FLAGS = [...]`` list,
    not one exact line. Harmless upstream edits therefore no longer silently
    skip the patch. Such edits include reordering, extra flags, quote style
    or a multi-line list. For the known upstream line the result is exactly
    ``CXX_FLAGS = ["-g", "-O3", "-std=c++17", "-DENABLE_BF16"]``.

    The file is rewritten only when something actually changed. The final
    status message reflects whether the patch was applied.

    Args:
        setup_py_path: Path to the ``setup.py`` file to patch in place.

    Raises:
        OSError: If the file cannot be read or written (e.g. FileNotFoundError).
    """
    import re

    print(f"--- [SageAttention Build] Applying patches to {setup_py_path} ---")

    with open(setup_py_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # A `CXX_FLAGS = [...]` list assignment; [^\]]* may span several lines.
    cxx_flags_re = re.compile(r'^[ \t]*CXX_FLAGS\s*=\s*\[[^\]]*\]', re.MULTILINE)
    # One quoted OpenMP flag together with its leading whitespace and trailing comma.
    openmp_flag_re = re.compile(r'\s*["\']-(?:fopenmp|lgomp)["\']\s*,?')

    patched = cxx_flags_re.sub(lambda m: openmp_flag_re.sub("", m.group(0)), content)

    if patched == content:
        print("⚠️ Patch 1/1: CXX_FLAGS line not found as expected. It might have been changed upstream. Skipping.")
        # Previously this path still rewrote the file and reported success.
        print("⚠️ No patches were applied; setup.py left unchanged.")
        return

    print("🔧 Patch 1/1: Removed '-fopenmp' and '-lgomp' from CXX_FLAGS.")
    with open(setup_py_path, 'w', encoding='utf-8') as f:
        f.write(patched)

    print("✅ Patches applied successfully.")


def install_sage_attention():
    """Clone, patch and build SageAttention unless ``REPO_DIR`` already exists.

    The presence of the ``REPO_DIR`` directory is the "already installed"
    marker. The build has three steps:

    1. Clone ``REPO_URL``.
    2. Patch ``setup.py`` via ``patch_setup_py``.
    3. Run ``setup.py install`` with build-tuning environment variables.

    If any step fails, or the run is interrupted, the freshly cloned
    directory is removed. The next invocation then retries the build instead
    of wrongly concluding that SageAttention is already installed.

    Exits the process with status 1 on failure.
    """
    import shutil

    print("--- [SageAttention Build] Checking environment ---")

    if os.path.isdir(REPO_DIR):
        print(f"✅ Directory '{REPO_DIR}' already exists, assuming SageAttention is installed. Skipping build.")
        return

    print(f"⏳ Directory '{REPO_DIR}' not found. Starting a fresh installation of SageAttention.")

    succeeded = False
    try:
        print("--- [SageAttention Build] Step 1/3: Cloning repository ---")
        try:
            run_command(["git", "clone", REPO_URL])
        except FileNotFoundError:
            # Only at this step does FileNotFoundError mean the git executable
            # is missing; later it could be a missing setup.py or REPO_DIR.
            print("❌ ERROR: 'git' command not found. Please ensure Git is installed in your environment.")
            sys.exit(1)
        print("✅ Repository cloned successfully.")

        print("--- [SageAttention Build] Step 2/3: Patching setup.py ---")
        setup_py_path = os.path.join(REPO_DIR, "setup.py")
        patch_setup_py(setup_py_path)

        print("--- [SageAttention Build] Step 3/3: Compiling and installing ---")

        # These values override anything already set in the caller's environment.
        # TORCH_CUDA_ARCH_LIST="9.0" builds for compute capability 9.0 only;
        # EXT_PARALLEL is presumably read by SageAttention's setup.py — verify upstream.
        build_env = os.environ.copy()
        build_env.update({
            "TORCH_CUDA_ARCH_LIST": "9.0",
            "EXT_PARALLEL": "4",
            "NVCC_APPEND_FLAGS": "--threads 8",
            "MAX_JOBS": "32",
        })
        print("🔧 Setting build environment variables:")
        print(f"   - TORCH_CUDA_ARCH_LIST='{build_env['TORCH_CUDA_ARCH_LIST']}'")
        print(f"   - EXT_PARALLEL={build_env['EXT_PARALLEL']}")
        print(f"   - NVCC_APPEND_FLAGS='{build_env['NVCC_APPEND_FLAGS']}'")
        print(f"   - MAX_JOBS={build_env['MAX_JOBS']}")

        install_command = [sys.executable, "setup.py", "install"]
        run_command(install_command, cwd=REPO_DIR, env=build_env)

        print("🎉 SageAttention compiled and installed successfully! ---")
        succeeded = True

    except subprocess.CalledProcessError as e:
        print(f"❌ Command failed with return code: {e.returncode}")
        print(f"❌ Command: {' '.join(e.cmd)}")
        print("❌ SageAttention installation failed. Please check the logs above for details.")
        sys.exit(1)
    except Exception as e:
        print(f"❌ An unknown error occurred: {e}")
        sys.exit(1)
    finally:
        if not succeeded:
            # REPO_DIR did not exist when we started, so anything there now is
            # a partial checkout from this attempt. Leaving it behind would make
            # the next run skip the build while nothing is actually installed.
            shutil.rmtree(REPO_DIR, ignore_errors=True)

if __name__ == "__main__":
    # An existing checkout makes install_sage_attention() skip the build,
    # so tell the user up front how to force a fresh one.
    checkout_exists = os.path.isdir(REPO_DIR)
    if checkout_exists:
        print(f"Note: To force a rebuild, please delete the '{REPO_DIR}' directory first.")
    install_sage_attention()