"""Debug harness for compiling the megablocks extension sources in stages.

Each stage calls ``torch.utils.cpp_extension.load`` on a different subset of
the ``csrc`` sources, so a compile failure or hang can be attributed to a
specific translation unit. A failing stage is reported (message plus
traceback) and the remaining stages still run; the process exits with status
1 if any stage failed. Where the platform supports SIGALRM, a watchdog aborts
the whole run after ``TIMEOUT_SECONDS``.
"""

import os
import pathlib
import signal
import sys
import traceback
from dataclasses import dataclass
from typing import Tuple

from torch.utils.cpp_extension import load

# Wall-clock budget for ALL stages combined, not per stage.
TIMEOUT_SECONDS = 180  # 3 minute timeout

# Compiler/include settings shared by every stage.
INCLUDE_PATHS = ["csrc"]
CFLAGS = ["-O3", "-std=c++17"]
CUDA_CFLAGS = ["-O3"]


@dataclass(frozen=True)
class BuildCase:
    """One compile stage and the messages printed around it."""

    header: str  # section banner printed before the stage
    intro: str  # progress line printed right before compiling
    label: str  # prefix of the "✓ ... successful" / "✗ ... failed" lines
    name: str  # extension name passed to load(); distinct per stage
    sources: Tuple[str, ...]  # source files, relative to the working directory
    extra_ldflags: Tuple[str, ...] = ()  # linker flags; not passed when empty


BUILD_CASES = (
    BuildCase(
        header="=== Testing with Single Source File ===",
        intro="Building with just new_cumsum.cu...",
        label="Single source",
        name="_megablocks_debug_single",
        sources=("csrc/new_cumsum.cu",),
    ),
    BuildCase(
        header="\n=== Testing with Two Source Files ===",
        intro="Building with new_cumsum.cu and new_histogram.cu...",
        label="Double source",
        name="_megablocks_debug_double",
        sources=("csrc/new_cumsum.cu", "csrc/new_histogram.cu"),
    ),
    BuildCase(
        header="\n=== Testing with grouped_gemm.cu Only ===",
        intro="Building with just grouped_gemm.cu (most complex)...",
        label="grouped_gemm",
        name="_megablocks_debug_gemm",
        sources=("csrc/grouped_gemm/grouped_gemm.cu",),
        extra_ldflags=("-lhipblaslt",),
    ),
)


def timeout_handler(signum, frame):
    """SIGALRM handler: report that the run hung and exit with status 1."""
    print("Build timed out - this indicates a hanging issue")
    sys.exit(1)


def run_build(case, loader=load):
    """Compile one stage, print its outcome, and return True on success.

    ``loader`` defaults to ``torch.utils.cpp_extension.load`` and is called
    with keyword arguments only. Any ``Exception`` it raises is reported
    (message on stdout, traceback on stderr) and turned into a False return
    so later stages still run. ``SystemExit`` raised by ``timeout_handler``
    is not an ``Exception`` subclass, so a timeout still aborts the run.
    """
    print(case.header)
    print(case.intro)
    kwargs = dict(
        name=case.name,
        sources=list(case.sources),
        extra_include_paths=list(INCLUDE_PATHS),
        extra_cflags=list(CFLAGS),
        extra_cuda_cflags=list(CUDA_CFLAGS),
        verbose=True,
        is_python_module=False,
    )
    # Only the stages that need extra linker flags pass them at all.
    if case.extra_ldflags:
        kwargs["extra_ldflags"] = list(case.extra_ldflags)
    try:
        loader(**kwargs)
    except Exception as exc:
        print(f"✗ {case.label} build failed: {exc}")
        traceback.print_exc()
        return False
    print(f"✓ {case.label} build successful")
    return True


def main():
    """Run every stage under the watchdog; return 0 if all succeeded, else 1."""
    repo = pathlib.Path(".").resolve()
    # Use a dedicated debug build directory unless the caller already set one.
    os.environ.setdefault(
        "TORCH_EXTENSIONS_DIR", str(repo / ".torch_extensions_debug")
    )

    # SIGALRM / alarm() are POSIX-only; elsewhere run without a watchdog.
    use_alarm = hasattr(signal, "SIGALRM")
    if use_alarm:
        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(TIMEOUT_SECONDS)
    try:
        # Build a full list (not a short-circuiting all() over a generator)
        # so every stage runs even after an earlier one fails.
        results = [run_build(case) for case in BUILD_CASES]
    finally:
        if use_alarm:
            signal.alarm(0)  # cancel the watchdog even on an early exit
    return 0 if all(results) else 1


if __name__ == "__main__":
    sys.exit(main())