# graph_v18.py - Optimized for 3060 TI (8GB VRAM) and similar low-VRAM GPUs # Copyright (C) 2025 Arcee AI # SPDX-License-Identifier: LGPL-3.0-only """ Module for computational graph execution. Classes: Task: Abstract base class representing a computational task. Executor: Class for scheduling and executing directed acyclic task graphs. """ import os import sys import gc import logging import networkx import torch import tqdm from pydantic import BaseModel from typing_extensions import Generic, TypeVar from abc import ABC, abstractmethod from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union from mergekit.common import get_torch_accelerator_module # ============================================================================ # CONFIGURATION SECTION - TUNE THESE PARAMETERS FOR YOUR GPU # ============================================================================ # --- PRIMARY VRAM TARGETS --- # For 3060 TI (8GB): Start with 7.2-7.4GB. Increase if stable, decrease if OOM. # For 3060 (12GB): Try 10.5-11.0GB # For 4GB cards: Try 3.2-3.5GB TARGET_VRAM_GB = 7.7 # Target VRAM usage in GB (TUNE THIS FIRST) # Safety margin to account for PyTorch overhead and fragmentation # Windows typically needs ~0.8GB, Linux ~0.5GB VRAM_SAFETY_MARGIN_GB = 0.2 # Reduce to 0.5-0.6 on Linux, increase to 1.0 if unstable # --- CUDA MEMORY ALLOCATOR CONFIGURATION --- # Smaller values = less fragmentation but more overhead # 24MB is optimal for 8GB cards, 32MB for 12GB+ cards CUDA_MAX_SPLIT_SIZE_MB = 24 # Options: 16, 24, 32, 64 # --- CHUNK SIZE BEHAVIOR --- # How aggressively to reduce chunk size on OOM (0.5-0.9 range) # Lower = more conservative (slower but safer), Higher = more aggressive CHUNK_REDUCTION_FACTOR = 0.75 # Options: 0.5 (safe), 0.7 (balanced), 0.85 (aggressive) # Minimum chunk size before giving up and falling back to CPU MIN_CHUNK_SIZE = 1 # Usually keep at 1, increase to 4-8 if seeing micro-chunk overhead # Enable power-of-2 alignment for chunk sizes (following measure.py strategy) # This improves memory allocation efficiency ENABLE_POWER_OF_2_ALIGNMENT = True # Set False if causing issues # --- TASK-SPECIFIC MEMORY MULTIPLIERS --- # These control how much extra VRAM to reserve for specific task types # Increase if task OOMs, decrease if underutilizing VRAM TASK_MULTIPLIERS = { "ModelStock": 2.2, # Options: 1.8-2.5 (needs room for pairwise similarities) "Karcher": 3.0, # Options: 2.5-3.5 (iterative, needs working memory) "Consensus": 3.0, # Options: 2.5-3.5 (similar to Karcher) "default": 1.2, # Options: 1.0-1.5 (general tasks) } # --- MEMORY CLEANUP BEHAVIOR --- # Enable aggressive garbage collection and cache clearing # True = slower but more stable, False = faster but may fragment memory ENABLE_AGGRESSIVE_CLEANUP = False # Set False if merges are very stable # How often to force cleanup (every N tasks). 0 = after every task CLEANUP_FREQUENCY = 10 # Options: 0 (always), 1, 2, 5, 10 # --- FALLBACK STRATEGY --- # Fixed chunk sizes to try if adaptive chunking fails # Powers of 2 work best for GPU memory alignment FALLBACK_CHUNK_SIZES = [4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2] # --- FAST PATH OPTIMIZATION --- # Try to execute entire task at once before chunking # True = faster when it works, False = always chunk (more conservative) ENABLE_FAST_PATH = True # Set False if getting frequent OOM on large tasks # --- TASK ROUTING --- # Tasks that should always run on CPU (typically I/O bound) CPU_ONLY_TASKS = [ "LoadTensor", "GatherTensors", "SaveTensor", "TensorWriterTask", "FinalizeModel", "PermutedEmbeddings", # Gather operations don't benefit from GPU ] # ============================================================================ # END OF CONFIGURATION SECTION # ============================================================================ if sys.platform == "win32": os.environ["PYTORCH_CUDA_ALLOC_CONF"] = f"max_split_size_mb:{CUDA_MAX_SPLIT_SIZE_MB}" ValueT = TypeVar("ValueT") LOG = logging.getLogger(__name__) def _round_to_power_of_2(n: int, prefer_lower: bool = True) -> int: """Round to nearest power of 2 for memory alignment.""" if n <= 0: return 1 if n == 1: return 1 # Find the two nearest powers of 2 power = n.bit_length() - 1 lower = 1 << power upper = 1 << (power + 1) if prefer_lower or (n - lower) < (upper - n): return lower return upper class Task(ABC, BaseModel, Generic[ValueT], frozen=True): @abstractmethod def arguments(self) -> Dict[str, "Task"]: ... @abstractmethod def execute(self, **kwargs) -> ValueT: ... def priority(self) -> int: return 0 def group_label(self) -> Optional[str]: return None def uses_accelerator(self) -> bool: return False def main_thread_only(self) -> bool: return False def duplicate_per_gpu(self) -> bool: return False class TaskUniverse: tasks: List[Task] task_to_index: Dict[Task, int] task_arguments: Dict[int, Dict[str, int]] _type_id_to_index: Dict[Tuple[type, int], int] def __init__(self, tasks: Optional[Iterable[Task]] = None): self.tasks = [] self.task_to_index = {} self.task_arguments = {} self._type_id_to_index = {} if tasks is not None: for task in tasks: self.add_task(task) def add_task(self, task: Task, recursive: bool = True) -> "TaskHandle": _ti_key = (type(task), id(task)) if _ti_key in self._type_id_to_index: index = self._type_id_to_index[_ti_key] return TaskHandle(self, index) index = self.task_to_index.setdefault(task, len(self.tasks)) if index < len(self.tasks): return TaskHandle(self, index) self.tasks.append(task) self._type_id_to_index[_ti_key] = index if recursive: self.task_arguments[index] = {} for k, v in task.arguments().items(): self.task_arguments[index][k] = self.add_task(v, recursive=True)._index return TaskHandle(self, index) def get_handle(self, task: Task) -> Optional["TaskHandle"]: if task not in self.task_to_index: return None return TaskHandle(self, self.task_to_index[task]) class TaskHandle: __slots__ = ["_universe", "_index"] _universe: TaskUniverse _index: int def __init__(self, universe: TaskUniverse, index: int): self._universe = universe self._index = index def task(self) -> Task: return self._universe.tasks[self._index] def arguments(self) -> Dict[str, "TaskHandle"]: return { k: TaskHandle(self._universe, v) for k, v in self._universe.task_arguments[self._index].items() } def __eq__(self, other): if not isinstance(other, TaskHandle): return False return self._index == other._index and self._universe is other._universe def __hash__(self): return self._index def __str__(self): return f"TaskHandle({type(self.task()).__name__}, {self._index})" __repr__ = __str__ class ExecutionSchedule: tasks: List[TaskHandle] last_use_index: Dict[TaskHandle, int] def __init__(self, tasks: List[TaskHandle], last_use_index: Dict[TaskHandle, int]): self.tasks = tasks self.last_use_index = last_use_index def build_schedule( targets: List[TaskHandle], cached_values: Dict[TaskHandle, Any] ) -> ExecutionSchedule: if not targets: return ExecutionSchedule(tasks=[], last_use_index={}) universe = targets[0]._universe dummy_handle = TaskHandle(universe, -1) edge_tups: List[Tuple[TaskHandle, TaskHandle]] = [] explored = set() to_explore = set(targets) while to_explore: task = to_explore.pop() if task in explored: continue explored.add(task) if task in (cached_values or {}): continue for dep in task.arguments().values(): to_explore.add(dep) edge_tups.append((dep, task)) for target in targets: edge_tups.append((dummy_handle, target)) def _compare_key(node: TaskHandle) -> Tuple[str, int]: if node._index < 0: return ("", 0) task = node.task() return (task.group_label() or "", -task.priority()) graph = networkx.DiGraph(edge_tups) schedule: List[TaskHandle] = [ node for node in networkx.lexicographical_topological_sort(graph, key=_compare_key) if (node != dummy_handle) and node not in (cached_values or {}) ] last_use_index = {} for idx, task in reversed(list(enumerate(schedule))): for dep in task.arguments().values(): if dep not in last_use_index: last_use_index[dep] = idx if task not in last_use_index: last_use_index[task] = idx for task in cached_values or {}: if task not in last_use_index: last_use_index[task] = len(schedule) + 1 return ExecutionSchedule(tasks=schedule, last_use_index=last_use_index) class Executor: math_device: torch.device storage_device: torch.device universe: TaskUniverse targets: List[TaskHandle] schedule: ExecutionSchedule cached_values: Optional[Dict[TaskHandle, Any]] _task_counter: int def __init__( self, targets: Union[List[Task], List[TaskHandle]], math_device: torch.device = torch.device("cpu"), storage_device: torch.device = torch.device("cpu"), cached_values: Optional[Dict[TaskHandle, Any]] = None, ): self.cached_values = cached_values self._task_counter = 0 if isinstance(math_device, str): math_device = torch.device(math_device) if isinstance(storage_device, str): storage_device = torch.device(storage_device) self.math_device = math_device self.storage_device = storage_device if targets and isinstance(targets[0], Task): universe = TaskUniverse(targets) targets = [universe.add_task(t) for t in targets] elif targets and isinstance(targets[0], TaskHandle): universe = targets[0]._universe elif not targets: universe = TaskUniverse() else: raise ValueError("Targets must be a list of Task or TaskHandle instances") self.universe = universe self.targets = targets self.schedule = build_schedule(targets, cached_values=cached_values) def _slice_argument(self, arg: Any, start: int, end: int) -> Any: """Recursively slice tensors within nested structures.""" if isinstance(arg, torch.Tensor): if arg.shape[0] > 1: return arg[start:end] return arg elif isinstance(arg, dict): return {k: self._slice_argument(v, start, end) for k, v in arg.items()} elif isinstance(arg, list): return [self._slice_argument(v, start, end) for v in arg] elif isinstance(arg, tuple): return tuple(self._slice_argument(v, start, end) for v in arg) return arg def _get_memory_stats(self) -> Dict[str, float]: """Get current VRAM statistics in GB.""" if self.math_device.type != "cuda": return {} allocated = torch.cuda.memory_allocated(self.math_device) / (1024**3) reserved = torch.cuda.memory_reserved(self.math_device) / (1024**3) total = torch.cuda.get_device_properties(self.math_device).total_memory / (1024**3) return { "allocated_gb": allocated, "reserved_gb": reserved, "total_gb": total, "free_gb": total - allocated, } def _get_adaptive_chunk_size(self, task: Task, arguments: Dict[str, Any]) -> int: """ Calculate optimal chunk size based on available VRAM and task requirements. This implements the "measure.py strategy" of targeting a specific VRAM fill level rather than using currently available memory, which prevents oscillation. """ if self.math_device.type == "cpu": return 1024 # Large default for CPU # Get hardware capacity total_vram = torch.cuda.get_device_properties(self.math_device).total_memory target_bytes = TARGET_VRAM_GB * (1024**3) # Analyze tensor dimensions and count num_tensors = 0 width = 0 bytes_per_element = 4 # Default float32 for arg in arguments.values(): if isinstance(arg, torch.Tensor): num_tensors += 1 width = max(width, arg.shape[-1] if len(arg.shape) > 1 else arg.shape[0]) bytes_per_element = arg.element_size() elif isinstance(arg, dict): for v in arg.values(): if isinstance(v, torch.Tensor): num_tensors += 1 width = max(width, v.shape[-1] if len(v.shape) > 1 else v.shape[0]) bytes_per_element = v.element_size() if num_tensors == 0 or width == 0: return 512 # Safe default # Get task-specific multiplier task_name = type(task).__name__ multiplier = TASK_MULTIPLIERS.get("default", 1.2) for key, mult in TASK_MULTIPLIERS.items(): if key in task_name: multiplier = mult break # Calculate bytes per row with multiplier for working memory bytes_per_row = num_tensors * width * bytes_per_element * multiplier # Calculate usable VRAM (target minus current allocation and safety margin) current_allocated = torch.cuda.memory_allocated(self.math_device) safety_bytes = VRAM_SAFETY_MARGIN_GB * (1024**3) usable_vram = max(target_bytes - current_allocated - safety_bytes, 1024 * (1024**2)) # Calculate chunk size chunk_size = max(MIN_CHUNK_SIZE, int(usable_vram // bytes_per_row)) # Apply power-of-2 alignment if enabled (measure.py strategy) if ENABLE_POWER_OF_2_ALIGNMENT and chunk_size > MIN_CHUNK_SIZE: chunk_size = _round_to_power_of_2(chunk_size, prefer_lower=True) LOG.debug(f"Calculated chunk size: {chunk_size} (tensors={num_tensors}, width={width}, mult={multiplier:.2f})") return chunk_size def _execute_chunked(self, task: Task, arguments: Dict[str, Any]) -> Any: """ Execute task in chunks with progressive fallback strategy. Strategy: 1. Try adaptive chunk size 2. On OOM, reduce by CHUNK_REDUCTION_FACTOR 3. Continue until success or MIN_CHUNK_SIZE reached """ # Find total rows to process total_rows = 0 for arg in arguments.values(): if isinstance(arg, torch.Tensor): total_rows = arg.shape[0] break elif isinstance(arg, dict): for v in arg.values(): if isinstance(v, torch.Tensor): total_rows = v.shape[0] break if total_rows > 0: break if total_rows == 0: return task.execute(**arguments) # Calculate initial chunk size chunk_size = self._get_adaptive_chunk_size(task, arguments) # FAST PATH: Try to execute all at once if chunk size >= total rows if ENABLE_FAST_PATH and chunk_size >= total_rows: try: gpu_args = { k: self._move_tensors(v, self.math_device) for k, v in arguments.items() } res = task.execute(**gpu_args) result = self._move_tensors(res, self.storage_device) del gpu_args, res if ENABLE_AGGRESSIVE_CLEANUP: torch.cuda.empty_cache() return result except torch.OutOfMemoryError: LOG.warning(f"Fast path OOM, falling back to chunking") torch.cuda.empty_cache() gc.collect() chunk_size = max(MIN_CHUNK_SIZE, total_rows // 2) # Chunked execution with progressive reduction results = [] i = 0 oom_count = 0 while i < total_rows: end = min(i + chunk_size, total_rows) try: chunk_args_gpu = { k: self._move_tensors(self._slice_argument(v, i, end), self.math_device) for k, v in arguments.items() } chunk_res = task.execute(**chunk_args_gpu) results.append(self._move_tensors(chunk_res, self.storage_device)) del chunk_args_gpu, chunk_res # Aggressive cleanup per measure.py strategy if ENABLE_AGGRESSIVE_CLEANUP: torch.cuda.empty_cache() i = end # Move to next chunk oom_count = 0 # Reset OOM counter on success except torch.OutOfMemoryError: oom_count += 1 torch.cuda.empty_cache() gc.collect() # Progressive reduction old_chunk = chunk_size chunk_size = max(MIN_CHUNK_SIZE, int(chunk_size * CHUNK_REDUCTION_FACTOR)) # Apply power-of-2 alignment if ENABLE_POWER_OF_2_ALIGNMENT: chunk_size = _round_to_power_of_2(chunk_size, prefer_lower=True) if chunk_size < MIN_CHUNK_SIZE: LOG.error(f"Chunk size below minimum ({MIN_CHUNK_SIZE}), cannot continue") raise LOG.warning( f"OOM at chunk {old_chunk}, reducing to {chunk_size} " f"(attempt {oom_count}, progress: {i}/{total_rows})" ) # Safety: if we OOM too many times, something is wrong if oom_count > 10: LOG.error("Too many OOM errors, giving up") raise # Concatenate results if not results: return None if isinstance(results[0], torch.Tensor): return torch.cat(results, dim=0) elif isinstance(results[0], dict): out = {} for k in results[0].keys(): out[k] = torch.cat([r[k] for r in results], dim=0) return out return results def _execute_with_fallback(self, task: Task, arguments: Dict[str, Any], accelerator) -> Any: """ Execute task with comprehensive fallback strategy. Strategy: 1. Try full GPU execution 2. Try adaptive chunking 3. Try fixed chunk sizes 4. Fall back to CPU """ task_name = type(task).__name__ # Strategy 1: Try full GPU execution for light tasks try: gpu_args = { k: self._move_tensors(v, self.math_device) for k, v in arguments.items() } res = task.execute(**gpu_args) result = self._move_tensors(res, self.storage_device) del gpu_args, res return result except torch.OutOfMemoryError: LOG.debug(f"Full GPU execution failed for {task_name}, trying chunked") torch.cuda.empty_cache() gc.collect() except Exception as e: LOG.warning(f"GPU execution error for {task_name}: {e}") torch.cuda.empty_cache() raise # Strategy 2: Try adaptive chunking try: result = self._execute_chunked(task, arguments) return result except torch.OutOfMemoryError: LOG.warning(f"Adaptive chunking failed for {task_name}, trying fixed sizes") torch.cuda.empty_cache() gc.collect() except Exception as e: LOG.warning(f"Chunking error for {task_name}: {e}") raise # Strategy 3: Try fixed chunk sizes for chunk_size in FALLBACK_CHUNK_SIZES: if chunk_size < MIN_CHUNK_SIZE: continue try: LOG.info(f"Trying fixed chunk size {chunk_size} for {task_name}") # Get total rows total_rows = 0 for arg in arguments.values(): if isinstance(arg, torch.Tensor): total_rows = arg.shape[0] break elif isinstance(arg, dict): for v in arg.values(): if isinstance(v, torch.Tensor): total_rows = v.shape[0] break if total_rows > 0: break if total_rows == 0: break results = [] for i in range(0, total_rows, chunk_size): end = min(i + chunk_size, total_rows) chunk_args = { k: self._slice_argument(v, i, end) for k, v in arguments.items() } chunk_args_gpu = { k: self._move_tensors(v, self.math_device) for k, v in chunk_args.items() } chunk_res = task.execute(**chunk_args_gpu) results.append(self._move_tensors(chunk_res, self.storage_device)) del chunk_args, chunk_args_gpu, chunk_res if ENABLE_AGGRESSIVE_CLEANUP: torch.cuda.empty_cache() if isinstance(results[0], torch.Tensor): return torch.cat(results, dim=0) elif isinstance(results[0], dict): out = {} for k in results[0].keys(): out[k] = torch.cat([r[k] for r in results], dim=0) return out return results except torch.OutOfMemoryError: torch.cuda.empty_cache() gc.collect() continue except Exception as e: LOG.warning(f"Fixed chunk {chunk_size} failed: {e}") break # Strategy 4: CPU fallback LOG.warning(f"All GPU strategies failed for {task_name}, using CPU") raise torch.OutOfMemoryError("Forcing CPU fallback") def _run( self, quiet: bool = False, desc: Optional[str] = None, ) -> Iterator[Tuple[TaskHandle, Any]]: last_use_index = self.schedule.last_use_index values: Dict[TaskHandle, Any] = {} if self.cached_values: for task, value in self.cached_values.items(): values[task] = value is_gpu_execution = self.math_device.type != "cpu" accelerator = get_torch_accelerator_module(self.math_device.type) if is_gpu_execution else None for idx, task_handle in ( pbar := tqdm.tqdm( list(enumerate(self.schedule.tasks)), disable=quiet, desc=desc or "Executing graph", ) ): task = task_handle.task() task_type = type(task).__name__ # Log memory stats periodically if is_gpu_execution and idx % 10 == 0: stats = self._get_memory_stats() LOG.debug( f"Memory: {stats.get('allocated_gb', 0):.2f}GB allocated, " f"{stats.get('free_gb', 0):.2f}GB free of {stats.get('total_gb', 0):.2f}GB" ) # Determine execution strategy is_cpu_only_task = task_type in CPU_ONLY_TASKS want_gpu = is_gpu_execution and task.uses_accelerator() and not is_cpu_only_task # Collect arguments arguments = {k: values[h] for k, h in task_handle.arguments().items()} success = False # Try GPU execution if want_gpu: try: res = self._execute_with_fallback(task, arguments, accelerator) values[task_handle] = res success = True except torch.OutOfMemoryError: LOG.warning(f"All GPU strategies exhausted for {task_type}, falling back to CPU") success = False except Exception as e: LOG.error(f"GPU execution failed for {task_type}: {e}") success = False # Cleanup after GPU attempt if is_gpu_execution and ENABLE_AGGRESSIVE_CLEANUP: gc.collect() if accelerator: accelerator.empty_cache() # CPU fallback if not success: if want_gpu: LOG.info(f"Executing {task_type} on CPU") # Ensure cleanup before CPU execution if is_gpu_execution: gc.collect() if accelerator: accelerator.empty_cache() # Move arguments to CPU cpu_arguments = { k: self._move_tensors(v, torch.device("cpu")) for k, v in arguments.items() } res = task.execute(**cpu_arguments) del cpu_arguments res = self._move_tensors(res, self.storage_device) values[task_handle] = res del res del arguments if task_handle in self.targets: yield (task_handle, values[task_handle]) # Evict unreferenced values expired = [] for key in values: if idx >= last_use_index[key]: expired.append(key) for key in expired: del values[key] # Periodic cleanup (measure.py strategy) self._task_counter += 1 if is_gpu_execution and ENABLE_AGGRESSIVE_CLEANUP: if CLEANUP_FREQUENCY == 0 or self._task_counter % max(1, CLEANUP_FREQUENCY) == 0: gc.collect() if accelerator: accelerator.empty_cache() del values del pbar def run( self, quiet: bool = False, desc: Optional[str] = None, ) -> Iterator[Tuple[Task, Any]]: for handle, value in self._run(quiet=quiet, desc=desc): yield (handle.task(), value) def execute(self, desc: Optional[str] = None) -> None: for _ in self.run(desc=desc): pass def _move_tensors( self, value: Any, device: torch.device, non_blocking: Optional[bool] = None ) -> Any: """Move tensors to specified device, handling nested structures.""" if non_blocking is None: non_blocking = device.type in ["cuda", "xpu"] if isinstance(value, torch.Tensor): if value.device == device: return value return value.to(device=device, non_blocking=non_blocking) elif isinstance(value, dict): return { k: self._move_tensors(v, device, non_blocking) for k, v in value.items() } elif isinstance(value, list): return [self._move_tensors(v, device, non_blocking) for v in value] elif isinstance(value, tuple): return tuple(self._move_tensors(v, device, non_blocking) for v in value) return value