#!/usr/bin/env bash
# Debug script 4: MegaBlocks-specific build debugging.
#
# Run from the MegaBlocks repository root. The script:
#   1. checks that the expected sources exist,
#   2. compiles each HIP translation unit on its own with hipcc,
#   3. drives torch.utils.cpp_extension.load() with progressively larger
#      source sets to locate where an extension build fails or hangs.
#
# Environment (all optional; defaults are set in the exports below):
#   ROCM_PATH, ROCM_HOME, HIP_PATH, HIP_HOME  ROCm install locations
#   TORCH_HIP_ARCH_LIST       target GPU arch(es), separated by ';', ',' or space
#   PYTORCH_ROCM_ARCH         arch list read by torch.utils.cpp_extension
#                             (defaults to TORCH_HIP_ARCH_LIST)
#   HSA_OVERRIDE_GFX_VERSION  ROCr gfx override in "major.minor.stepping" form
#   TORCH_EXTENSIONS_DIR      build cache used by the load() tests
set -euo pipefail

echo "=== MegaBlocks Build Debug Script 4 ==="
echo "Testing MegaBlocks-specific compilation components"
echo

# Set ROCm environment variables
export ROCM_PATH="${ROCM_PATH:-/opt/rocm-7.0.1}"
export ROCM_HOME="${ROCM_HOME:-$ROCM_PATH}"
export HIP_PATH="${HIP_PATH:-$ROCM_PATH}"
export HIP_HOME="${HIP_HOME:-$ROCM_PATH}"
export PATH="$ROCM_HOME/bin:$PATH"
export TORCH_HIP_ARCH_LIST="${TORCH_HIP_ARCH_LIST:-gfx942}"
# torch.utils.cpp_extension reads PYTORCH_ROCM_ARCH (';'- or space-separated),
# not TORCH_HIP_ARCH_LIST. Without it, load() targets every arch PyTorch was
# built for, so the load() tests would not match the hipcc tests below.
export PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH:-${TORCH_HIP_ARCH_LIST//,/;}}"
# ROCr expects a numeric "major.minor.stepping" override (gfx942 -> 9.4.2);
# a "gfx942"-style string is not a valid value for this variable.
export HSA_OVERRIDE_GFX_VERSION="${HSA_OVERRIDE_GFX_VERSION:-9.4.2}"
export TORCH_EXTENSIONS_DIR="${TORCH_EXTENSIONS_DIR:-$PWD/.torch_extensions_debug}"

echo "Working directory: $(pwd)"
echo

# Private scratch space for object files and the generated Python drivers.
# mktemp avoids predictable /tmp names, and cleanup removes only what this
# script created (a blanket `rm -f /tmp/*.o` would delete unrelated files).
# The EXIT trap also cleans up when the script aborts part-way.
work_dir=$(mktemp -d "${TMPDIR:-/tmp}/megablocks_debug.XXXXXX")
cleanup() {
  rm -rf -- "${work_dir:?}"
}
trap cleanup EXIT

echo "=== Checking MegaBlocks Source Files ==="
echo "Verifying all source files exist:"

sources=(
  "torch-ext/torch_binding.cpp"
  "csrc/new_cumsum.cu"
  "csrc/new_histogram.cu"
  "csrc/new_indices.cu"
  "csrc/new_replicate.cu"
  "csrc/new_sort.cu"
  "csrc/grouped_gemm/grouped_gemm.cu"
)

all_exist=true
for src in "${sources[@]}"; do
  if [[ -f "$src" ]]; then
    echo "✓ $src exists ($(wc -l < "$src") lines)"
  else
    echo "✗ $src missing"
    all_exist=false
  fi
done

if [[ "$all_exist" == false ]]; then
  echo "Cannot proceed - missing source files"
  exit 1
fi

# Every compile test below and torch's load() invoke hipcc.
if ! command -v hipcc >/dev/null 2>&1; then
  echo "Cannot proceed - hipcc not found in PATH (ROCM_HOME=$ROCM_HOME)"
  exit 1
fi

echo
echo "=== Checking Include Directories ==="
if [[ -d csrc ]]; then
  echo "✓ csrc include directory exists"
  echo "Headers in csrc/:"
  # sed consumes all of find's output, so find is never killed by SIGPIPE
  # (with `| head`, pipefail + set -e could turn that into a script abort).
  find csrc \( -name "*.h" -o -name "*.hpp" \) | sed -n '1,10p'
else
  echo "✗ csrc include directory missing"
fi

echo
echo "=== Testing Individual Source Compilation ==="

# Offload-arch flags derived from TORCH_HIP_ARCH_LIST rather than a
# hard-coded arch. `--amdgpu-target` is the deprecated spelling; current
# hipcc/clang (and torch.utils.cpp_extension itself) use `--offload-arch`.
IFS=';, ' read -r -a hip_archs <<<"$TORCH_HIP_ARCH_LIST"
arch_flags=()
for arch in "${hip_archs[@]}"; do
  if [[ -n "$arch" ]]; then
    arch_flags+=("--offload-arch=$arch")
  fi
done

# Header search paths: repo headers, every PyTorch include directory (not
# only the first), and the Python include dir, in case a source pulls in
# torch/extension.h, which needs Python.h. torch.utils.cpp_extension is
# imported explicitly because `import torch` alone does not guarantee that
# submodule is loaded. Resolved once instead of spawning Python per compile.
include_flags=(-I./csrc)
if torch_includes=$(python3 -c '
import sysconfig
from torch.utils.cpp_extension import include_paths
for path in include_paths():
    print(path)
print(sysconfig.get_paths()["include"])
'); then
  while IFS= read -r inc; do
    if [[ -n "$inc" ]]; then
      include_flags+=("-I$inc")
    fi
  done <<<"$torch_includes"
else
  echo "⚠ Could not query PyTorch include paths; compiling with -I./csrc only"
fi

common_flags=("${include_flags[@]}" -std=c++17 -O3 -fPIC)

#######################################
# Compile one source file to an object in the scratch dir and report it.
# Globals:   work_dir, common_flags (read)
# Arguments: $1 - timeout in seconds
#            $2 - source path
#            $3... - extra hipcc flags
# Outputs:   compiler output, then a ✓/✗ status line on stdout
# Returns:   0 always, so one failing file does not abort the script
#######################################
try_compile() {
  local limit=$1 src=$2
  shift 2
  local name obj rc=0
  name=$(basename "$src")
  obj="$work_dir/${name%.*}.o"
  timeout "$limit" hipcc -c "$src" -o "$obj" "${common_flags[@]}" "$@" || rc=$?
  if ((rc == 0)); then
    echo "✓ $name compiled successfully"
  elif ((rc == 124)); then
    # timeout(1) exits 124 when the time limit expires.
    echo "✗ $name compilation failed (timed out after ${limit}s - possible hang)"
  else
    echo "✗ $name compilation failed (exit status $rc)"
  fi
}

# Test compiling each top-level .cu file individually
shopt -s nullglob
for src in csrc/*.cu; do
  echo "Testing compilation of $(basename "$src")..."
  try_compile 60 "$src" "${arch_flags[@]}"
done
shopt -u nullglob

echo
echo "=== Testing grouped_gemm.cu Specifically ==="
echo "This is often the most complex kernel..."
# -v prints the full driver command lines. No -lhipblaslt here: -c stops
# before linking, so a library flag would only be reported as unused.
try_compile 120 csrc/grouped_gemm/grouped_gemm.cu "${arch_flags[@]}" -v

echo
echo "=== Testing torch_binding.cpp ==="
try_compile 60 torch-ext/torch_binding.cpp

#######################################
# Run a generated Python driver and report, not propagate, its exit status,
# so a failed or timed-out stage does not skip the later stages or cleanup.
# Arguments: $1 - driver script path
#            $2 - label for status messages
# Returns:   0 always
#######################################
run_driver() {
  local script=$1 label=$2 rc=0
  # -u: unbuffered, so stage headers are not reordered relative to the
  # compiler output that ninja writes directly to the terminal.
  python3 -u "$script" || rc=$?
  if ((rc == 124)); then
    echo "⚠ $label timed out (exit status 124)"
  elif ((rc != 0)); then
    echo "⚠ $label exited with status $rc"
  fi
}

echo
echo "=== Testing Incremental PyTorch Extension Build ==="
cat > "$work_dir/debug_build.py" << 'EOF'
import os
import pathlib
import signal
import sys

from torch.utils.cpp_extension import load

# Per-stage budget, so one hang or slow stage cannot starve the stages after it.
STAGE_TIMEOUT_S = 180


class BuildTimeout(BaseException):
    """Raised from the SIGALRM handler when a stage exceeds STAGE_TIMEOUT_S.

    Derives from BaseException so generic `except Exception` handlers inside
    the build machinery cannot swallow it.
    """


def timeout_handler(signum, frame):
    raise BuildTimeout()


signal.signal(signal.SIGALRM, timeout_handler)

repo = pathlib.Path(".").resolve()
os.environ.setdefault("TORCH_EXTENSIONS_DIR", str(repo / ".torch_extensions_debug"))

timed_out = False
failed = False


def try_build(label, name, sources, extra_ldflags=None):
    """Build one extension variant and report success, failure or timeout."""
    global timed_out, failed
    signal.alarm(STAGE_TIMEOUT_S)
    try:
        load(
            name=name,
            sources=sources,
            extra_include_paths=["csrc"],
            extra_cflags=["-O3", "-std=c++17"],
            extra_cuda_cflags=["-O3"],
            extra_ldflags=extra_ldflags,
            verbose=True,
            is_python_module=False,
        )
        print(f"✓ {label} build successful")
    except BuildTimeout:
        timed_out = True
        print(f"✗ {label} build timed out after {STAGE_TIMEOUT_S}s"
              " - this indicates a hanging issue")
    except Exception as e:
        failed = True
        print(f"✗ {label} build failed: {e}")
    finally:
        signal.alarm(0)  # Cancel timeout


print("=== Testing with Single Source File ===")
print("Building with just new_cumsum.cu...")
try_build("Single source", "_megablocks_debug_single", ["csrc/new_cumsum.cu"])

print("\n=== Testing with Two Source Files ===")
print("Building with new_cumsum.cu and new_histogram.cu...")
try_build("Double source", "_megablocks_debug_double",
          ["csrc/new_cumsum.cu", "csrc/new_histogram.cu"])

print("\n=== Testing with grouped_gemm.cu Only ===")
print("Building with just grouped_gemm.cu (most complex)...")
try_build("grouped_gemm", "_megablocks_debug_gemm",
          ["csrc/grouped_gemm/grouped_gemm.cu"], extra_ldflags=["-lhipblaslt"])

# 124 mirrors timeout(1); 1 means at least one stage failed outright.
sys.exit(124 if timed_out else (1 if failed else 0))
EOF

echo "Running incremental build test..."
run_driver "$work_dir/debug_build.py" "Incremental build test"

echo
echo "=== Testing Full Build with Timeout ==="
cat > "$work_dir/debug_full_build.py" << 'EOF'
import os
import pathlib
import signal
import sys
import traceback

from torch.utils.cpp_extension import load


def timeout_handler(signum, frame):
    print("Full build timed out - this confirms the hanging issue")
    sys.exit(124)  # timeout exit code


# Set up 5 minute timeout
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(300)

repo = pathlib.Path(".").resolve()
os.environ.setdefault("TORCH_EXTENSIONS_DIR", str(repo / ".torch_extensions_debug"))

sources = [
    "torch-ext/torch_binding.cpp",
    "csrc/new_cumsum.cu",
    "csrc/new_histogram.cu",
    "csrc/new_indices.cu",
    "csrc/new_replicate.cu",
    "csrc/new_sort.cu",
    "csrc/grouped_gemm/grouped_gemm.cu",
]

print("=== Attempting Full MegaBlocks Build ===")
print("This mimics the exact build.py process...")
print("Sources:", sources)

status = 0
try:
    mod = load(
        name="_megablocks_debug_full",
        sources=sources,
        extra_include_paths=["csrc"],
        extra_cflags=["-O3", "-std=c++17"],
        extra_cuda_cflags=["-O3"],
        extra_ldflags=["-lhipblaslt"],
        verbose=True,
        is_python_module=False,
    )
    print("✓ Full build successful!")
    print("Built:", mod)
except Exception as e:
    print(f"✗ Full build failed: {e}")
    traceback.print_exc()
    status = 1
finally:
    signal.alarm(0)

sys.exit(status)
EOF

echo "Running full build test (with timeout)..."
run_driver "$work_dir/debug_full_build.py" "Full build test"

echo
echo "=== Cleanup ==="
# Removes only this run's scratch dir (objects + generated drivers); the
# EXIT trap repeats this harmlessly and covers early exits.
cleanup

echo
echo "=== Debug Script 4 Complete ==="