import torch
import multiprocessing
import os
import sys
def occupy_gpu(device_id, memory_fraction=0.90, compute_size=8192):
    """
    Target function for a process to occupy a specific GPU.

    Args:
        device_id (int): The ID of the GPU to occupy (e.g., 1 for cuda:1).
        memory_fraction (float): Fraction of free memory to try to allocate (0.0 to 1.0).
        compute_size (int): Dimension of the square matrices for the matmul compute load.
            Larger values increase compute intensity but also use some memory.
    """
    # Compute these before the try block so the except handler below can
    # reference them even if CUDA setup itself fails
    process_id = os.getpid()
    device = f'cuda:{device_id}'
    try:
        # Ensure this process targets the correct GPU
        torch.cuda.set_device(device_id)
        print(f"[PID {process_id}] Targeting {device}...")
        # --- 1. Allocate Memory ---
        allocated_tensor = None
        try:
            # Query free and total memory on the target device
            free_mem, total_mem = torch.cuda.mem_get_info(device_id)
            target_alloc_bytes = int(free_mem * memory_fraction)
            print(f"[PID {process_id}] {device}: Total Mem={total_mem/1024**3:.2f} GB, Free Mem={free_mem/1024**3:.2f} GB")
            print(f"[PID {process_id}] {device}: Attempting to allocate ~{target_alloc_bytes/1024**3:.2f} GB ({memory_fraction*100:.0f}% of free)...")
            # Calculate the tensor size (float32 = 4 bytes per element)
            elements_needed = target_alloc_bytes // 4
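            # Worked example (illustrative numbers, not measured): with ~40 GiB free
            # and memory_fraction=0.90, target_alloc_bytes is ~36 GiB, which at
            # 4 bytes per float32 element is roughly 9.7e9 elements.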
            # Create a 1D tensor, as its size is simplest to calculate
            allocated_tensor = torch.empty(elements_needed, dtype=torch.float32, device=device)
            # Fill it with data to ensure the allocation actually happens (it can otherwise be lazy)
            allocated_tensor.fill_(1.0)
            torch.cuda.synchronize(device_id)  # Wait for the allocation to complete
            # Verify the allocated memory (approximate, as PyTorch reserves some overhead)
            allocated_bytes = allocated_tensor.nelement() * allocated_tensor.element_size()
            print(f"[PID {process_id}] {device}: Successfully allocated tensor using ~{allocated_bytes/1024**3:.2f} GB.")
            # The tensor stays referenced for the lifetime of this function,
            # so the allocation is held until the process exits.
        except RuntimeError as e:
            print(f"[PID {process_id}] {device}: ERROR allocating memory - {e}. Memory usage might be lower.")
            print(f"[PID {process_id}] {device}: Check if {memory_fraction*100:.0f}% is too high or whether other processes are using memory.")
        # Continue to the compute loop even if the memory allocation failed partially or fully
        # --- 2. Run Compute Load ---
        print(f"[PID {process_id}] {device}: Starting compute loop (matmul {compute_size}x{compute_size})...")
        # Create tensors for the computation
        try:
            a = torch.randn(compute_size, compute_size, dtype=torch.float32, device=device)
            b = torch.randn(compute_size, compute_size, dtype=torch.float32, device=device)
        except RuntimeError as e:
            print(f"[PID {process_id}] {device}: ERROR creating compute tensors ({compute_size}x{compute_size}) - {e}.")
            print(f"[PID {process_id}] {device}: The GPU might not have enough remaining memory for this compute size. Try reducing 'compute_size'. Exiting process.")
            return  # Exit this process if we can't even create the compute tensors
        # Infinite compute loop
        while True:
            # Perform a compute-intensive operation
            c = torch.matmul(a, b)
            # Optional: add more operations if matmul alone isn't maxing out utilization
            # a = a * 1.0001  # Avoid values growing too large/small quickly
            # b = b + 0.0001
            # torch.cuda.synchronize(device_id)  # Usually not needed in a tight loop like this
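            # If host-side queuing of kernels ever becomes a concern, a periodic
            # synchronize bounds the backlog (sketch; assumes an iteration counter
            # 'i' is added to the loop):
            # if i % 1000 == 0:
            #     torch.cuda.synchronize(device_id)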
            # We don't need to do anything with 'c'; the goal is just the computation.
            # No sleep here; we want maximum utilization.
    except Exception as e:
        # Log any other errors that might occur
        print(f"[PID {process_id}] {device}: UNEXPECTED ERROR - {e}")
if __name__ == "__main__":
    # --- Configuration ---
    TARGET_GPU_IDS = [0, 1]        # <<< Your target GPU IDs here (cuda:0, cuda:1)
    MEMORY_FRACTION_TO_USE = 0.85  # <<< Try to use 85% of *free* memory. Adjust if needed (0.8 to 0.95 is typical)
    COMPUTE_MATRIX_DIM = 8192      # <<< Dimension for matmul (e.g., 8192, 10240, 12288).
                                   #     Larger = more compute-intensive bursts, but uses more temporary memory.
                                   #     Adjust based on GPU capability and remaining memory after allocation.
    # --- End Configuration ---
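    # Alternative (sketch, using an assumed env var name): derive the targets from
    # the environment instead of hard-coding them:
    #   TARGET_GPU_IDS = [int(x) for x in os.environ.get("OCCUPY_GPUS", "0").split(",")]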
    # Check CUDA availability and device count
    if not torch.cuda.is_available():
        print("Error: CUDA is not available. Please check your PyTorch installation and CUDA drivers.")
        sys.exit(1)
    num_gpus = torch.cuda.device_count()
    print(f"Found {num_gpus} CUDA devices.")
    valid_target_gpus = []
    for gpu_id in TARGET_GPU_IDS:
        if gpu_id < 0 or gpu_id >= num_gpus:
            print(f"Warning: GPU ID {gpu_id} is invalid (must be between 0 and {num_gpus-1}). Skipping.")
        else:
            valid_target_gpus.append(gpu_id)
    if not valid_target_gpus:
        print("Error: No valid target GPUs specified or available. Exiting.")
        sys.exit(1)
    print(f"Attempting to occupy GPUs: {valid_target_gpus}")
    print(f"Memory target: {MEMORY_FRACTION_TO_USE*100:.0f}% of free memory per GPU.")
    print(f"Compute load: Matrix multiplication of size {COMPUTE_MATRIX_DIM}x{COMPUTE_MATRIX_DIM}.")
    print("-" * 30)
    # Set the multiprocessing start method. 'spawn' is important for CUDA:
    # a forked child cannot re-initialize CUDA once the parent has touched it.
    try:
        multiprocessing.set_start_method('spawn', force=True)
    except RuntimeError:
        print("Note: Could not set multiprocessing start method to 'spawn'. Using default.")
    processes = []
    for gpu_id in valid_target_gpus:
        p = multiprocessing.Process(target=occupy_gpu, args=(gpu_id, MEMORY_FRACTION_TO_USE, COMPUTE_MATRIX_DIM))
        processes.append(p)
        p.start()
    print("\nProcesses started. Monitor GPU usage with 'nvidia-smi'.")
    print("Press Ctrl+C to stop the script and terminate processes.")
    try:
        # Keep the main script alive while the child processes run
        for p in processes:
            p.join()  # They won't finish unless they error out or are terminated
    except KeyboardInterrupt:
        print("\nCtrl+C detected. Terminating GPU occupation processes...")
        for p in processes:
            if p.is_alive():
                p.terminate()      # Send SIGTERM
                p.join(timeout=5)  # Wait at most 5 seconds for a graceful exit
                if p.is_alive():
                    print(f"Process {p.pid} did not terminate gracefully, killing.")
                    p.kill()       # Send SIGKILL if necessary
                    p.join()       # Wait for the kill to take effect
        print("All processes terminated.")
    except Exception as main_e:
        print(f"An error occurred in the main process: {main_e}")
        # Clean up any child processes that are still alive
        for p in processes:
            if p.is_alive():
                p.terminate()