# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only
"""
Module for computational graph execution.

Classes:
    Task: Abstract base class representing a computational task.
    Executor: Class for scheduling and executing directed acyclic task graphs.
"""

import gc
import logging
import os
import sys
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union

import networkx
import torch
import tqdm
from pydantic import BaseModel
from typing_extensions import Generic, TypeVar

from mergekit.common import get_torch_accelerator_module

# Windows/NVIDIA-specific allocator tuning to reduce fragmentation
if sys.platform == "win32":
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"

ValueT = TypeVar("ValueT")
LOG = logging.getLogger(__name__)


class Task(ABC, BaseModel, Generic[ValueT], frozen=True):
    """Abstract base class representing a computational task.

    Tasks are immutable pydantic models; their dependencies are other
    tasks, whose results are passed to ``execute`` by keyword.
    """

    @abstractmethod
    def arguments(self) -> Dict[str, "Task"]:
        """Return this task's dependencies, keyed by argument name."""
        ...

    @abstractmethod
    def execute(self, **kwargs) -> ValueT:
        """Compute the task's value from its arguments' values."""
        ...

    def priority(self) -> int:
        return 0

    def group_label(self) -> Optional[str]:
        return None

    def uses_accelerator(self) -> bool:
        return False

    def main_thread_only(self) -> bool:
        return False

    def duplicate_per_gpu(self) -> bool:
        return False
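

# Illustrative sketch (not part of the library): a minimal concrete Task.
# The class and field names below are hypothetical; real tasks follow the
# same pattern of declaring dependencies via `arguments` and receiving the
# computed values as keyword arguments to `execute`.
class ExampleAddTensors(Task[torch.Tensor]):
    a: Task
    b: Task

    def arguments(self) -> Dict[str, Task]:
        return {"a": self.a, "b": self.b}

    def execute(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return a + b

    def uses_accelerator(self) -> bool:
        # Element-wise addition benefits from running on the math device
        return True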


class TaskUniverse:
    """Flat registry of tasks and their dependency structure.

    Tasks are deduplicated both by value (equality) and by identity, so
    repeated additions of the same task yield the same handle.
    """

    tasks: List[Task]
    task_to_index: Dict[Task, int]
    task_arguments: Dict[int, Dict[str, int]]
    _type_id_to_index: Dict[Tuple[type, int], int]

    def __init__(self, tasks: Optional[Iterable[Task]] = None):
        self.tasks = []
        self.task_to_index = {}
        self.task_arguments = {}
        self._type_id_to_index = {}
        if tasks is not None:
            for task in tasks:
                self.add_task(task)

    def add_task(self, task: Task, recursive: bool = True) -> "TaskHandle":
        """Add a task (and, if recursive, its arguments) to the universe."""
        # Fast path: this exact object has already been added
        _ti_key = (type(task), id(task))
        if _ti_key in self._type_id_to_index:
            return TaskHandle(self, self._type_id_to_index[_ti_key])
        index = self.task_to_index.setdefault(task, len(self.tasks))
        if index < len(self.tasks):
            # An equal task is already registered
            return TaskHandle(self, index)
        self.tasks.append(task)
        self._type_id_to_index[_ti_key] = index
        if recursive:
            self.task_arguments[index] = {}
            for k, v in task.arguments().items():
                self.task_arguments[index][k] = self.add_task(
                    v, recursive=True
                )._index
        return TaskHandle(self, index)

    def get_handle(self, task: Task) -> Optional["TaskHandle"]:
        """Return a handle for a previously added task, or None."""
        if task not in self.task_to_index:
            return None
        return TaskHandle(self, self.task_to_index[task])


class TaskHandle:
    """Lightweight reference to a task within a TaskUniverse."""

    __slots__ = ["_universe", "_index"]
    _universe: TaskUniverse
    _index: int

    def __init__(self, universe: TaskUniverse, index: int):
        self._universe = universe
        self._index = index

    def task(self) -> Task:
        return self._universe.tasks[self._index]

    def arguments(self) -> Dict[str, "TaskHandle"]:
        return {
            k: TaskHandle(self._universe, v)
            for k, v in self._universe.task_arguments[self._index].items()
        }

    def __eq__(self, other):
        if not isinstance(other, TaskHandle):
            return False
        return self._index == other._index and self._universe is other._universe

    def __hash__(self):
        return self._index

    def __str__(self):
        return f"TaskHandle({type(self.task()).__name__}, {self._index})"

    __repr__ = __str__


class ExecutionSchedule:
    """Linear execution order plus value-lifetime information."""

    tasks: List[TaskHandle]
    last_use_index: Dict[TaskHandle, int]

    def __init__(self, tasks: List[TaskHandle], last_use_index: Dict[TaskHandle, int]):
        self.tasks = tasks
        self.last_use_index = last_use_index


def build_schedule(
    targets: List[TaskHandle], cached_values: Dict[TaskHandle, Any]
) -> ExecutionSchedule:
    """Topologically sort the dependency graph of the target tasks.

    Tasks with cached values are treated as leaves. Within the topological
    order, tasks are grouped by group_label and ordered by priority.
    """
    if not targets:
        return ExecutionSchedule(tasks=[], last_use_index={})
    universe = targets[0]._universe
    # Sentinel node ensuring every target appears in the graph
    dummy_handle = TaskHandle(universe, -1)
    edge_tups: List[Tuple[TaskHandle, TaskHandle]] = []
    explored = set()
    to_explore = set(targets)
    while to_explore:
        task = to_explore.pop()
        if task in explored:
            continue
        explored.add(task)
        if task in (cached_values or {}):
            # Cached tasks don't need their dependencies computed
            continue
        for dep in task.arguments().values():
            to_explore.add(dep)
            edge_tups.append((dep, task))
    for target in targets:
        edge_tups.append((dummy_handle, target))

    def _compare_key(node: TaskHandle) -> Tuple[str, int]:
        if node._index < 0:
            return ("", 0)
        task = node.task()
        return (task.group_label() or "", -task.priority())

    graph = networkx.DiGraph(edge_tups)
    schedule: List[TaskHandle] = [
        node
        for node in networkx.lexicographical_topological_sort(graph, key=_compare_key)
        if (node != dummy_handle) and node not in (cached_values or {})
    ]

    # For each handle, record the index of the last task consuming its value
    # so the executor can evict it afterwards
    last_use_index = {}
    for idx, task in reversed(list(enumerate(schedule))):
        for dep in task.arguments().values():
            if dep not in last_use_index:
                last_use_index[dep] = idx
        if task not in last_use_index:
            last_use_index[task] = idx
    for task in cached_values or {}:
        if task not in last_use_index:
            last_use_index[task] = len(schedule) + 1
    return ExecutionSchedule(tasks=schedule, last_use_index=last_use_index)
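

# Scheduling sketch (hypothetical tasks): for a dependency chain A -> B -> C
# with C as the sole target, `build_schedule([C], {})` yields
# tasks == [A, B, C] and last_use_index == {A: 1, B: 2, C: 2}, so the
# executor may evict A's value as soon as B (schedule index 1) has run.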


class Executor:
    """Schedules and executes a directed acyclic graph of tasks.

    Values are computed on math_device and held on storage_device between
    tasks; they are evicted as soon as no remaining task needs them.
    """

    math_device: torch.device
    storage_device: torch.device
    universe: TaskUniverse
    targets: List[TaskHandle]
    schedule: ExecutionSchedule
    cached_values: Optional[Dict[TaskHandle, Any]]

    def __init__(
        self,
        targets: Union[List[Task], List[TaskHandle]],
        math_device: torch.device = torch.device("cpu"),
        storage_device: torch.device = torch.device("cpu"),
        cached_values: Optional[Dict[TaskHandle, Any]] = None,
    ):
        self.cached_values = cached_values
        if isinstance(math_device, str):
            math_device = torch.device(math_device)
        if isinstance(storage_device, str):
            storage_device = torch.device(storage_device)
        self.math_device = math_device
        self.storage_device = storage_device
        if targets and isinstance(targets[0], Task):
            universe = TaskUniverse(targets)
            targets = [universe.add_task(t) for t in targets]
        elif targets and isinstance(targets[0], TaskHandle):
            universe = targets[0]._universe
        elif not targets:
            universe = TaskUniverse()
        else:
            raise ValueError("targets must be a list of Task or TaskHandle instances")
        self.universe = universe
        self.targets = targets
        self.schedule = build_schedule(targets, cached_values=cached_values)

    def _slice_argument(self, arg: Any, start: int, end: int) -> Any:
        """Slice tensors along dim 0 within nested structures."""
        if isinstance(arg, torch.Tensor):
            # Only slice tensors that actually have rows to split; scalars
            # and singleton/broadcast tensors pass through unchanged
            if arg.dim() > 0 and arg.shape[0] > 1:
                return arg[start:end]
            return arg
        elif isinstance(arg, dict):
            return {k: self._slice_argument(v, start, end) for k, v in arg.items()}
        elif isinstance(arg, list):
            return [self._slice_argument(v, start, end) for v in arg]
        elif isinstance(arg, tuple):
            return tuple(self._slice_argument(v, start, end) for v in arg)
        return arg

    def _execute_chunked(
        self, task: Task, arguments: Dict[str, Any], chunk_size: int
    ) -> Any:
        """
        Execute a task by splitting input tensors into row chunks, processing
        each chunk on the math device, and concatenating the results on the
        storage device.
        """
        # Find a reference tensor to determine the number of rows
        ref_tensor = None
        for arg in arguments.values():
            if isinstance(arg, torch.Tensor):
                ref_tensor = arg
                break
            elif isinstance(arg, dict):
                for v in arg.values():
                    if isinstance(v, torch.Tensor):
                        ref_tensor = v
                        break
            if ref_tensor is not None:
                break
        if ref_tensor is None:
            raise ValueError("No tensors found to chunk")
        total_rows = ref_tensor.shape[0]
        results = []
        accelerator = (
            get_torch_accelerator_module(self.math_device.type)
            if self.math_device.type != "cpu"
            else None
        )
        # Process in chunks
        for i in range(0, total_rows, chunk_size):
            end = min(i + chunk_size, total_rows)
            # Slice inputs and move only the current chunk to the math device
            chunk_args = {
                k: self._slice_argument(v, i, end) for k, v in arguments.items()
            }
            chunk_args_gpu = {
                k: self._move_tensors(v, self.math_device)
                for k, v in chunk_args.items()
            }
            chunk_res = task.execute(**chunk_args_gpu)
            # Move the result to the storage device immediately to free VRAM
            chunk_res_cpu = self._move_tensors(chunk_res, self.storage_device)
            results.append(chunk_res_cpu)
            del chunk_args
            del chunk_args_gpu
            del chunk_res
            # Clear the accelerator cache each iteration so multi-step merge
            # methods don't accumulate fragmented allocations
            if accelerator:
                accelerator.empty_cache()
        # Concatenate partial results
        if isinstance(results[0], torch.Tensor):
            return torch.cat(results, dim=0)
        elif isinstance(results[0], dict):
            # Reassemble a dict of tensors key by key
            return {
                k: torch.cat([r[k] for r in results], dim=0)
                for k in results[0].keys()
            }
        else:
            raise ValueError("Unsupported return type for chunking")

    def _run(
        self,
        quiet: bool = False,
        desc: Optional[str] = None,
    ) -> Iterator[Tuple[TaskHandle, Any]]:
        last_use_index = self.schedule.last_use_index
        values: Dict[TaskHandle, Any] = {}
        if self.cached_values:
            for task, value in self.cached_values.items():
                values[task] = value
        is_gpu_execution = self.math_device.type != "cpu"
        accelerator = (
            get_torch_accelerator_module(self.math_device.type)
            if is_gpu_execution
            else None
        )
        for idx, task_handle in (
            pbar := tqdm.tqdm(
                list(enumerate(self.schedule.tasks)),
                disable=quiet,
                desc=desc or "Executing graph",
            )
        ):
            task = task_handle.task()
            task_type = type(task).__name__
            # Heuristic: don't force I/O tasks onto the GPU.
            # PermutedEmbeddings is essentially a gather operation -- hard to
            # chunk and better on CPU if memory is tight.
            is_io_task = task_type in [
                "LoadTensor",
                "GatherTensors",
                "SaveTensor",
                "TensorWriterTask",
                "FinalizeModel",
                "PermutedEmbeddings",
            ]
            want_gpu = is_gpu_execution and (task.uses_accelerator() or not is_io_task)
            success = False
            if want_gpu:
                try:
                    # 1. Try full GPU execution
                    arguments = {}
                    for name, dep_handle in task_handle.arguments().items():
                        value = values[dep_handle]
                        value = self._move_tensors(value, self.math_device)
                        arguments[name] = value
                    res = task.execute(**arguments)
                    del arguments
                    res = self._move_tensors(res, self.storage_device)
                    values[task_handle] = res
                    success = True
                except torch.OutOfMemoryError:
                    # Release whatever the failed attempt allocated
                    arguments = None
                    res = None
                    gc.collect()
                    if accelerator:
                        accelerator.empty_cache()
                    # 2. Try chunked GPU execution with adaptive sizing
                    chunk_sizes = [4096, 2048, 1024, 512, 256, 128, 64]
                    # Rebuild arguments; values are already on the storage device
                    arguments = {}
                    for name, dep_handle in task_handle.arguments().items():
                        arguments[name] = values[dep_handle]
                    for chunk_size in chunk_sizes:
                        try:
                            LOG.info(
                                f"OOM on {task_type}. Attempting chunked GPU "
                                f"execution (size={chunk_size})..."
                            )
                            res = self._execute_chunked(
                                task, arguments, chunk_size=chunk_size
                            )
                            values[task_handle] = res
                            success = True
                            LOG.info(
                                f"Chunked execution successful for {task_type} "
                                f"(size={chunk_size})"
                            )
                            break
                        except Exception as e:
                            LOG.warning(
                                f"Chunked execution failed at size {chunk_size} ({e})."
                            )
                            gc.collect()
                            if accelerator:
                                accelerator.empty_cache()
                            # If the failure wasn't an OOM (e.g. an indexing
                            # error), smaller chunks won't help -- stop trying
                            if not isinstance(e, torch.OutOfMemoryError):
                                break
            # 3. CPU fallback
            if not success:
                if want_gpu:
                    LOG.warning(
                        f"All GPU attempts failed for {task_type}. Falling back to CPU."
                    )
                # Clean up any GPU debris before the CPU attempt
                if is_gpu_execution:
                    gc.collect()
                    if accelerator:
                        accelerator.empty_cache()
                arguments = {}
                for name, dep_handle in task_handle.arguments().items():
                    value = values[dep_handle]
                    value = self._move_tensors(value, torch.device("cpu"))
                    arguments[name] = value
                res = task.execute(**arguments)
                del arguments
                res = self._move_tensors(res, self.storage_device)
                values[task_handle] = res
            del res
            if task_handle in self.targets:
                yield (task_handle, values[task_handle])
            # Evict values no remaining task will use
            expired = []
            for key in values:
                if idx >= last_use_index[key]:
                    expired.append(key)
            for key in expired:
                del values[key]
            # Aggressive cleanup between tasks
            if is_gpu_execution:
                gc.collect()
                if accelerator:
                    accelerator.empty_cache()
        del values
        del pbar

    def run(
        self,
        quiet: bool = False,
        desc: Optional[str] = None,
    ) -> Iterator[Tuple[Task, Any]]:
        """Run the schedule, yielding (task, value) pairs for each target."""
        for handle, value in self._run(quiet=quiet, desc=desc):
            yield (handle.task(), value)

    def execute(self, desc: Optional[str] = None) -> None:
        """Run the schedule to completion, discarding target values."""
        for _ in self.run(desc=desc):
            pass

    def _move_tensors(
        self, value: Any, device: torch.device, non_blocking: Optional[bool] = None
    ) -> Any:
        """Recursively move tensors in nested structures to a device."""
        if non_blocking is None:
            non_blocking = device.type in ["cuda", "xpu"]
        if isinstance(value, torch.Tensor):
            if value.device == device:
                return value
            return value.to(device=device, non_blocking=non_blocking)
        elif isinstance(value, dict):
            return {
                k: self._move_tensors(v, device, non_blocking)
                for k, v in value.items()
            }
        elif isinstance(value, list):
            return [self._move_tensors(v, device, non_blocking) for v in value]
        elif isinstance(value, tuple):
            return tuple(self._move_tensors(v, device, non_blocking) for v in value)
        return value
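

if __name__ == "__main__":
    # Smoke-test sketch using the hypothetical ExampleAddTensors task defined
    # above; nothing in this block is part of the library proper.
    class ConstTensor(Task[torch.Tensor]):
        value: float

        def arguments(self) -> Dict[str, Task]:
            return {}

        def execute(self) -> torch.Tensor:
            return torch.full((4, 4), self.value)

    total = ExampleAddTensors(a=ConstTensor(value=1.0), b=ConstTensor(value=2.0))
    for task, result in Executor([total]).run(quiet=True):
        print(type(task).__name__, result.mean().item())  # -> 3.0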