File size: 3,833 Bytes
f36e497
 
 
be200d7
f36e497
be200d7
 
 
 
f36e497
 
 
be200d7
f36e497
 
 
be200d7
f36e497
be200d7
f36e497
be200d7
f36e497
 
be200d7
 
f36e497
be200d7
 
f36e497
be200d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f36e497
be200d7
 
f36e497
be200d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f36e497
be200d7
 
 
 
f36e497
be200d7
 
f36e497
 
be200d7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import shutil
import subprocess
import sys
import textwrap

# Upstream SageAttention repository that install_sage_attention() clones.
REPO_URL = "https://github.com/thu-ml/SageAttention.git"
# Directory that `git clone REPO_URL` creates, relative to the current working
# directory. install_sage_attention() treats its mere presence as "already
# installed" and skips the build.
REPO_DIR = "SageAttention"

def run_command(command, cwd=None, env=None):
    """Run *command* and raise if it exits with a non-zero status.

    stdout and stderr are merged and captured rather than streamed. The
    combined output is printed only when the command fails, so a successful
    run stays quiet.

    Args:
        command: Argument list passed directly to ``subprocess.run`` (no shell).
        cwd: Working directory for the child process, or None for the current one.
        env: Environment mapping for the child process, or None to inherit.

    Returns:
        The ``subprocess.CompletedProcess`` of the successful run. Its merged
        output is available as ``.stdout``.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero. The
            captured output is attached as ``.output``.
        FileNotFoundError: if the executable (or *cwd*) does not exist.
    """
    # str() each item so PathLike arguments, which subprocess accepts, can be echoed too.
    print(f"🚀 Running command: {' '.join(map(str, command))}")
    result = subprocess.run(
        command,
        cwd=cwd,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        # Compiler logs may contain bytes that are invalid in the locale
        # encoding. Never let decoding them mask the command's real outcome.
        errors="replace",
    )

    if result.returncode != 0:
        # Dump the merged log so the failure can be diagnosed from the console.
        print(result.stdout)
        raise subprocess.CalledProcessError(
            result.returncode, command, output=result.stdout
        )
    return result

def patch_setup_py(setup_py_path):
    """Strip the OpenMP flags from SageAttention's ``setup.py`` in place.

    The function looks for the exact upstream ``CXX_FLAGS`` line containing
    ``-fopenmp`` and ``-lgomp``, and rewrites it without those two flags. If
    that line is not found verbatim (for example, because upstream changed
    it), the file is left untouched and a warning is printed.

    Args:
        setup_py_path: Path to the ``setup.py`` file to patch.

    Returns:
        True if the flags were removed, False if the expected line was absent.
    """
    print(f"--- [SageAttention Build] Applying patches to {setup_py_path} ---")

    # newline='' keeps the file's original line endings on the round trip.
    with open(setup_py_path, 'r', encoding='utf-8', newline='') as f:
        content = f.read()

    original_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]'
    modified_cxx_flags = 'CXX_FLAGS = ["-g", "-O3", "-std=c++17", "-DENABLE_BF16"]'

    if original_cxx_flags not in content:
        # Nothing to change: do not rewrite the file or claim success.
        print("⚠️ Patch 1/1: CXX_FLAGS line not found as expected. It might have been changed upstream. Skipping.")
        return False

    content = content.replace(original_cxx_flags, modified_cxx_flags)
    print("🔧 Patch 1/1: Removed '-fopenmp' and '-lgomp' from CXX_FLAGS.")

    with open(setup_py_path, 'w', encoding='utf-8', newline='') as f:
        f.write(content)

    print("✅ Patches applied successfully.")
    return True


def _build_environment():
    """Return a copy of ``os.environ`` with the SageAttention build settings applied.

    The values below unconditionally override any same-named variables that
    are already set in the caller's environment.
    """
    build_env = os.environ.copy()
    build_env.update({
        # NOTE(review): presumably restricts CUDA code generation to compute
        # capability 9.0 only. Confirm this matches every target GPU.
        "TORCH_CUDA_ARCH_LIST": "9.0",
        # Parallelism knobs for the compile step. Presumably consumed by
        # SageAttention's setup.py, nvcc and the extension builder; confirm.
        "EXT_PARALLEL": "4",
        "NVCC_APPEND_FLAGS": "--threads 8",
        "MAX_JOBS": "32",
    })
    return build_env


def install_sage_attention():
    """Clone, patch and build SageAttention unless REPO_DIR already exists.

    The presence of REPO_DIR is treated as "already installed". To keep that
    assumption valid, if any step fails after this call has started a fresh
    clone, the directory is removed. The next run then retries from scratch
    instead of silently skipping a broken build.

    Exits the interpreter with status 1 on any failure.
    """
    print("--- [SageAttention Build] Checking environment ---")

    if os.path.isdir(REPO_DIR):
        print(f"✅ Directory '{REPO_DIR}' already exists, assuming SageAttention is installed. Skipping build.")
        return

    print(f"⏳ Directory '{REPO_DIR}' not found. Starting a fresh installation of SageAttention.")

    # REPO_DIR did not exist on entry, so anything there on failure was created by us.
    succeeded = False
    try:
        print("--- [SageAttention Build] Step 1/3: Cloning repository ---")
        run_command(["git", "clone", REPO_URL])
        print("✅ Repository cloned successfully.")

        print("--- [SageAttention Build] Step 2/3: Patching setup.py ---")
        setup_py_path = os.path.join(REPO_DIR, "setup.py")
        patch_setup_py(setup_py_path)

        print("--- [SageAttention Build] Step 3/3: Compiling and installing ---")

        build_env = _build_environment()
        print("🔧 Setting build environment variables:")
        print(f"   - TORCH_CUDA_ARCH_LIST='{build_env['TORCH_CUDA_ARCH_LIST']}'")
        print(f"   - EXT_PARALLEL={build_env['EXT_PARALLEL']}")
        print(f"   - NVCC_APPEND_FLAGS='{build_env['NVCC_APPEND_FLAGS']}'")
        print(f"   - MAX_JOBS={build_env['MAX_JOBS']}")

        # NOTE(review): `setup.py install` is deprecated by setuptools;
        # consider `pip install .` once the build is verified to work with it.
        install_command = [sys.executable, "setup.py", "install"]

        run_command(install_command, cwd=REPO_DIR, env=build_env)
        succeeded = True

        print("🎉 SageAttention compiled and installed successfully! ---")

    except FileNotFoundError as e:
        # This is not always git: a missing setup.py or working directory
        # raises the same exception, so report what was actually missing.
        if e.filename == "git":
            print("❌ ERROR: 'git' command not found. Please ensure Git is installed in your environment.")
        else:
            print(f"❌ ERROR: Required file or program not found: {e}")
        sys.exit(1)
    except subprocess.CalledProcessError as e:
        print(f"❌ Command failed with return code: {e.returncode}")
        print(f"❌ Command: {' '.join(map(str, e.cmd))}")
        print("❌ SageAttention installation failed. Please check the logs above for details.")
        sys.exit(1)
    except Exception as e:
        print(f"❌ An unknown error occurred: {e}")
        sys.exit(1)
    finally:
        # Runs before SystemExit propagates; also covers KeyboardInterrupt.
        if not succeeded and os.path.isdir(REPO_DIR):
            print(f"🧹 Removing incomplete '{REPO_DIR}' so the next run starts fresh.")
            shutil.rmtree(REPO_DIR, ignore_errors=True)

def _main():
    """Script entry point: hint at how to force a rebuild, then run the installer."""
    repo_already_present = os.path.isdir(REPO_DIR)
    if repo_already_present:
        print(f"Note: To force a rebuild, please delete the '{REPO_DIR}' directory first.")
    install_sage_attention()


if __name__ == "__main__":
    _main()