# model_tools / graph_v4.py
# (web upload residue: "Naphula's picture / Upload 2 files / 94b8607 verified")
# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only
"""
Module for computational graph execution.
Classes:
Task: Abstract base class representing a computational task.
Executor: Class for scheduling and executing directed acyclic task graphs.
"""
import os
import sys
import gc
import logging
import networkx
import torch
import tqdm
from pydantic import BaseModel
from typing_extensions import Generic, TypeVar
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from mergekit.common import get_torch_accelerator_module
# Windows/NVIDIA specific allocator tuning to reduce fragmentation.
# This must be set before the first CUDA allocation to take effect.
# Use setdefault so an explicit user-provided setting is never clobbered.
if sys.platform == "win32":
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:32")

# Type of the value produced by a Task.
ValueT = TypeVar("ValueT")
LOG = logging.getLogger(__name__)
class Task(ABC, BaseModel, Generic[ValueT], frozen=True):
    """Abstract base class representing a computational task.

    Tasks are frozen (immutable) pydantic models, so instances are hashable
    and can be used as dict keys and graph nodes. Subclasses declare their
    dependencies via ``arguments`` and their computation via ``execute``.
    """
    @abstractmethod
    def arguments(self) -> Dict[str, "Task"]:
        """Return dependency tasks, keyed by the ``execute`` keyword argument
        name that will receive each dependency's result."""
        ...
    @abstractmethod
    def execute(self, **kwargs) -> ValueT:
        """Compute and return this task's value.

        ``kwargs`` contains the results of the tasks returned by
        ``arguments``, under the same keys.
        """
        ...
    def priority(self) -> int:
        """Scheduling priority; higher values are ordered earlier among
        tasks sharing a group label (see ``build_schedule``)."""
        return 0
    def group_label(self) -> Optional[str]:
        """Optional label used to keep related tasks adjacent in the
        topological schedule; None means no grouping."""
        return None
    def uses_accelerator(self) -> bool:
        """Whether this task benefits from running on an accelerator device."""
        return False
    def main_thread_only(self) -> bool:
        """Whether this task must run on the main thread. Not consulted in
        this module; presumably used by other executors — verify at callers."""
        return False
    def duplicate_per_gpu(self) -> bool:
        """Whether this task should be duplicated per GPU. Not consulted in
        this module; presumably used by multi-GPU executors — verify at callers."""
        return False
class TaskUniverse:
    """Interns Task objects and records their dependency structure.

    Every distinct task receives a stable integer index; dependencies are
    stored as index-to-index maps so that handles stay lightweight.
    """
    tasks: List[Task]
    task_to_index: Dict[Task, int]
    task_arguments: Dict[int, Dict[str, int]]
    # Fast-path cache keyed on (type, id) to skip hashing the task itself.
    _type_id_to_index: Dict[Tuple[type, int], int]

    def __init__(self, tasks: Optional[Iterable[Task]] = None):
        self.tasks = []
        self.task_to_index = {}
        self.task_arguments = {}
        self._type_id_to_index = {}
        for task in tasks or ():
            self.add_task(task)

    def add_task(self, task: Task, recursive: bool = True) -> "TaskHandle":
        """Register ``task`` (and, if ``recursive``, its dependencies),
        returning a handle to it. Adding an equal task twice yields a handle
        to the existing entry."""
        identity_key = (type(task), id(task))
        cached_index = self._type_id_to_index.get(identity_key)
        if cached_index is not None:
            # This exact object has been interned before.
            return TaskHandle(self, cached_index)
        index = self.task_to_index.setdefault(task, len(self.tasks))
        if index < len(self.tasks):
            # An equal (but not identical) task is already registered.
            return TaskHandle(self, index)
        self.tasks.append(task)
        self._type_id_to_index[identity_key] = index
        if recursive:
            arg_indices: Dict[str, int] = {}
            for name, dep in task.arguments().items():
                arg_indices[name] = self.add_task(dep, recursive=True)._index
            self.task_arguments[index] = arg_indices
        return TaskHandle(self, index)

    def get_handle(self, task: Task) -> Optional["TaskHandle"]:
        """Return a handle for ``task`` if it is in this universe, else None."""
        index = self.task_to_index.get(task)
        return None if index is None else TaskHandle(self, index)
class TaskHandle:
    """Cheap, hashable reference to a task stored in a TaskUniverse."""
    __slots__ = ["_universe", "_index"]
    _universe: TaskUniverse
    _index: int

    def __init__(self, universe: TaskUniverse, index: int):
        self._universe = universe
        self._index = index

    def task(self) -> Task:
        """Return the underlying Task object."""
        return self._universe.tasks[self._index]

    def arguments(self) -> Dict[str, "TaskHandle"]:
        """Return handles to this task's dependencies, keyed by argument name."""
        dep_indices = self._universe.task_arguments[self._index]
        return {name: TaskHandle(self._universe, dep) for name, dep in dep_indices.items()}

    def __eq__(self, other):
        return (
            isinstance(other, TaskHandle)
            and self._index == other._index
            and self._universe is other._universe
        )

    def __hash__(self):
        # The index alone is a sufficient hash; equality still checks the
        # universe identity.
        return self._index

    def __str__(self):
        task_name = type(self.task()).__name__
        return f"TaskHandle({task_name}, {self._index})"

    __repr__ = __str__
class ExecutionSchedule:
    """An ordered list of tasks to execute, plus, for each handle, the
    schedule index after which its value is no longer needed (used for
    eviction by the executor)."""
    tasks: List[TaskHandle]
    last_use_index: Dict[TaskHandle, int]

    def __init__(self, tasks: List[TaskHandle], last_use_index: Dict[TaskHandle, int]):
        self.last_use_index = last_use_index
        self.tasks = tasks
def build_schedule(
    targets: List[TaskHandle], cached_values: Dict[TaskHandle, Any]
) -> ExecutionSchedule:
    """Topologically order all tasks needed to produce ``targets``.

    Tasks with entries in ``cached_values`` are treated as already satisfied
    and excluded from the schedule. Also computes, per handle, the schedule
    index after which its value may be evicted.
    """
    if not targets:
        return ExecutionSchedule(tasks=[], last_use_index={})
    cached = cached_values or {}
    universe = targets[0]._universe
    # Sentinel node with an edge to every target, so all targets appear in a
    # single connected traversal.
    dummy_handle = TaskHandle(universe, -1)

    # Discover transitive dependencies, stopping at cached values.
    edge_tups: List[Tuple[TaskHandle, TaskHandle]] = []
    seen = set()
    frontier = set(targets)
    while frontier:
        node = frontier.pop()
        if node in seen:
            continue
        seen.add(node)
        if node in cached:
            continue
        for dep in node.arguments().values():
            frontier.add(dep)
            edge_tups.append((dep, node))
    edge_tups.extend((dummy_handle, target) for target in targets)

    def _compare_key(node: TaskHandle) -> Tuple[str, int]:
        # Keep same-group tasks adjacent; higher priority sorts earlier.
        if node._index < 0:
            return ("", 0)
        task = node.task()
        return (task.group_label() or "", -task.priority())

    graph = networkx.DiGraph(edge_tups)
    schedule: List[TaskHandle] = [
        node
        for node in networkx.lexicographical_topological_sort(graph, key=_compare_key)
        if node != dummy_handle and node not in cached
    ]

    # Walk the schedule backwards: the first time a handle is seen is its
    # final use.
    last_use_index: Dict[TaskHandle, int] = {}
    for idx in range(len(schedule) - 1, -1, -1):
        node = schedule[idx]
        for dep in node.arguments().values():
            last_use_index.setdefault(dep, idx)
        last_use_index.setdefault(node, idx)
    # Cached values never consumed by the schedule are kept past the end.
    for node in cached:
        last_use_index.setdefault(node, len(schedule) + 1)
    return ExecutionSchedule(tasks=schedule, last_use_index=last_use_index)
class Executor:
    """Schedules and executes a directed acyclic graph of tasks.

    Task computation runs on ``math_device``; results are parked on
    ``storage_device`` between tasks. When an accelerator runs out of
    memory, execution degrades gracefully per task:
    full-GPU -> chunked-GPU (shrinking chunk sizes) -> CPU.
    """

    math_device: torch.device
    storage_device: torch.device
    universe: TaskUniverse
    targets: List[TaskHandle]
    schedule: ExecutionSchedule
    cached_values: Optional[Dict[TaskHandle, Any]]

    def __init__(
        self,
        targets: Union[List[Task], List[TaskHandle]],
        math_device: torch.device = torch.device("cpu"),
        storage_device: torch.device = torch.device("cpu"),
        cached_values: Optional[Dict[TaskHandle, Any]] = None,
    ):
        """Build an execution schedule for ``targets``.

        Args:
            targets: Tasks (or handles into a shared universe) to produce.
            math_device: Device on which task computation runs.
            storage_device: Device on which results are held between tasks.
            cached_values: Optional pre-computed values for some handles.

        Raises:
            ValueError: If ``targets`` is neither a list of Task nor a list
                of TaskHandle instances.
        """
        self.cached_values = cached_values
        # Accept device strings for convenience.
        if isinstance(math_device, str):
            math_device = torch.device(math_device)
        if isinstance(storage_device, str):
            storage_device = torch.device(storage_device)
        self.math_device = math_device
        self.storage_device = storage_device
        if targets and isinstance(targets[0], Task):
            universe = TaskUniverse(targets)
            targets = [universe.add_task(t) for t in targets]
        elif targets and isinstance(targets[0], TaskHandle):
            universe = targets[0]._universe
        elif not targets:
            universe = TaskUniverse()
        else:
            raise ValueError("Targets must be a list of Task or TaskHandle instances")
        self.universe = universe
        self.targets = targets
        self.schedule = build_schedule(targets, cached_values=cached_values)

    def _slice_argument(self, arg: Any, start: int, end: int) -> Any:
        """Helper to slice tensors within nested structures.

        Tensors with a leading dimension greater than one are sliced along
        dim 0; 0-dim (scalar) and singleton tensors pass through unchanged,
        as do non-tensor leaves.
        """
        if isinstance(arg, torch.Tensor):
            # Guard ndim: a 0-dim tensor has no shape[0] and cannot be
            # sliced (the unguarded access raised IndexError).
            if arg.ndim > 0 and arg.shape[0] > 1:
                return arg[start:end]
            return arg
        elif isinstance(arg, dict):
            return {k: self._slice_argument(v, start, end) for k, v in arg.items()}
        elif isinstance(arg, list):
            return [self._slice_argument(v, start, end) for v in arg]
        elif isinstance(arg, tuple):
            return tuple(self._slice_argument(v, start, end) for v in arg)
        return arg

    def _execute_chunked(self, task: Task, arguments: Dict[str, Any], chunk_size: int) -> Any:
        """
        Executes a task by splitting input tensors into chunks, processing on GPU,
        and concatenating results on the storage device.

        Raises:
            ValueError: If no tensor can be found to chunk, the reference
                tensor is empty, or the task's return type cannot be
                reassembled by concatenation.
        """
        # Find a reference tensor to determine the number of rows to chunk.
        # NOTE(review): only top-level tensors and dict values are scanned;
        # tensors nested in lists/tuples are not considered as references.
        ref_tensor = None
        for arg in arguments.values():
            if isinstance(arg, torch.Tensor):
                ref_tensor = arg
                break
            elif isinstance(arg, dict):
                for v in arg.values():
                    if isinstance(v, torch.Tensor):
                        ref_tensor = v
                        break
            if ref_tensor is not None:
                break
        if ref_tensor is None:
            raise ValueError("No tensors found to chunk")
        if ref_tensor.ndim == 0 or ref_tensor.shape[0] == 0:
            # Previously fell through to an opaque IndexError on results[0].
            raise ValueError("Reference tensor has no rows to chunk")
        total_rows = ref_tensor.shape[0]
        results = []
        accelerator = (
            get_torch_accelerator_module(self.math_device.type)
            if self.math_device.type != "cpu"
            else None
        )
        # Process in chunks.
        for i in range(0, total_rows, chunk_size):
            end = min(i + chunk_size, total_rows)
            # Slice inputs on the storage device, then move only the chunk.
            chunk_args = {
                k: self._slice_argument(v, i, end)
                for k, v in arguments.items()
            }
            chunk_args_gpu = {
                k: self._move_tensors(v, self.math_device)
                for k, v in chunk_args.items()
            }
            chunk_res = task.execute(**chunk_args_gpu)
            # Move the partial result off the accelerator immediately to
            # keep peak memory low.
            chunk_res_cpu = self._move_tensors(chunk_res, self.storage_device)
            results.append(chunk_res_cpu)
            del chunk_args
            del chunk_args_gpu
            del chunk_res
            # Clear cache inside loop to handle complex methods like Magic
            if accelerator:
                accelerator.empty_cache()
        # Concatenate partial results back together.
        if isinstance(results[0], torch.Tensor):
            return torch.cat(results, dim=0)
        elif isinstance(results[0], dict):
            # Reassemble dict of tensors key by key.
            out = {}
            for k in results[0].keys():
                out[k] = torch.cat([r[k] for r in results], dim=0)
            return out
        else:
            raise ValueError("Unsupported return type for chunking")

    def _run(
        self,
        quiet: bool = False,
        desc: Optional[str] = None,
    ) -> Iterator[Tuple[TaskHandle, Any]]:
        """Execute the schedule, yielding (handle, value) for each target.

        Intermediate values are evicted as soon as their last consumer has
        run. On accelerator OOM, per-task fallback is attempted (see class
        docstring).
        """
        last_use_index = self.schedule.last_use_index
        values: Dict[TaskHandle, Any] = {}
        if self.cached_values:
            values.update(self.cached_values)
        # Set for O(1) membership tests in the hot loop (self.targets is a
        # list, so `in` would be O(n) per scheduled task).
        target_set = set(self.targets)
        is_gpu_execution = self.math_device.type != "cpu"
        accelerator = (
            get_torch_accelerator_module(self.math_device.type)
            if is_gpu_execution
            else None
        )
        for idx, task_handle in (
            pbar := tqdm.tqdm(
                list(enumerate(self.schedule.tasks)),
                disable=quiet,
                desc=desc or "Executing graph",
            )
        ):
            task = task_handle.task()
            task_type = type(task).__name__
            # Heuristic: Don't force I/O tasks to GPU
            # PermutedEmbeddings is essentially a gather operation, hard to chunk, better on CPU if memory is tight
            is_io_task = task_type in [
                "LoadTensor",
                "GatherTensors",
                "SaveTensor",
                "TensorWriterTask",
                "FinalizeModel",
                "PermutedEmbeddings",
            ]
            want_gpu = is_gpu_execution and (task.uses_accelerator() or not is_io_task)
            success = False
            if want_gpu:
                try:
                    # 1. Try full GPU execution: move all inputs to the math
                    # device and run the task in one shot.
                    arguments = {}
                    for name, dep_handle in task_handle.arguments().items():
                        value = values[dep_handle]
                        value = self._move_tensors(value, self.math_device)
                        arguments[name] = value
                    res = task.execute(**arguments)
                    del arguments
                    res = self._move_tensors(res, self.storage_device)
                    values[task_handle] = res
                    success = True
                except torch.OutOfMemoryError:
                    # Release whatever partial state the failed attempt left.
                    arguments = None
                    res = None
                    gc.collect()
                    if accelerator:
                        accelerator.empty_cache()
                    # 2. Try chunked GPU execution with adaptive sizing.
                    chunk_sizes = [4096, 2048, 1024, 512, 256, 128, 64]
                    # Rebuild arguments; values are still on the storage
                    # device (the try above never overwrote them).
                    arguments = {}
                    for name, dep_handle in task_handle.arguments().items():
                        arguments[name] = values[dep_handle]
                    for chunk_size in chunk_sizes:
                        try:
                            LOG.info(
                                "OOM on %s. Attempting chunked GPU execution (size=%s)...",
                                task_type,
                                chunk_size,
                            )
                            res = self._execute_chunked(task, arguments, chunk_size=chunk_size)
                            values[task_handle] = res
                            success = True
                            LOG.info(
                                "Chunked execution successful for %s (size=%s)",
                                task_type,
                                chunk_size,
                            )
                            break
                        except Exception as e:
                            LOG.warning(
                                "Chunked execution failed at size %s (%s).",
                                chunk_size,
                                str(e),
                            )
                            gc.collect()
                            if accelerator:
                                accelerator.empty_cache()
                            # A non-OOM failure (e.g. an index error) will not
                            # be fixed by a smaller chunk; stop retrying.
                            if not isinstance(e, torch.OutOfMemoryError):
                                break
            # 3. CPU fallback (also the primary path when GPU is not wanted).
            if not success:
                if want_gpu:
                    LOG.warning(
                        "All GPU attempts failed for %s. Falling back to CPU.",
                        task_type,
                    )
                # Ensure we clean up any GPU debris before the CPU attempt.
                if is_gpu_execution:
                    gc.collect()
                    if accelerator:
                        accelerator.empty_cache()
                arguments = {}
                for name, dep_handle in task_handle.arguments().items():
                    value = values[dep_handle]
                    value = self._move_tensors(value, torch.device("cpu"))
                    arguments[name] = value
                res = task.execute(**arguments)
                del arguments
                res = self._move_tensors(res, self.storage_device)
                values[task_handle] = res
            del res
            if task_handle in target_set:
                yield (task_handle, values[task_handle])
            # Evict values whose last consumer has now run.
            expired = [key for key in values if idx >= last_use_index[key]]
            for key in expired:
                del values[key]
            # Aggressive cleanup between tasks when running on accelerator.
            if is_gpu_execution:
                gc.collect()
                if accelerator:
                    accelerator.empty_cache()
        del values
        del pbar

    def run(
        self,
        quiet: bool = False,
        desc: Optional[str] = None,
    ) -> Iterator[Tuple[Task, Any]]:
        """Execute the graph, yielding (task, value) for each target task.

        Args:
            quiet: Suppress the progress bar.
            desc: Optional progress-bar description.
        """
        for handle, value in self._run(quiet=quiet, desc=desc):
            yield (handle.task(), value)

    def execute(self, desc: Optional[str] = None) -> None:
        """Run the graph to completion, discarding target values."""
        for _ in self.run(desc=desc):
            pass

    def _move_tensors(
        self, value: Any, device: torch.device, non_blocking: Optional[bool] = None
    ) -> Any:
        """Recursively move tensors inside nested dict/list/tuple structures
        to ``device``; non-tensor leaves are returned unchanged.

        Transfers default to non_blocking for cuda/xpu targets so copies can
        overlap with compute.
        """
        if non_blocking is None:
            non_blocking = device.type in ["cuda", "xpu"]
        if isinstance(value, torch.Tensor):
            # Note: torch.device("cuda") != tensor.device cuda:0, so such
            # tensors still go through .to(), which is a no-op copy check.
            if value.device == device:
                return value
            return value.to(device=device, non_blocking=non_blocking)
        elif isinstance(value, dict):
            return {k: self._move_tensors(v, device, non_blocking) for k, v in value.items()}
        elif isinstance(value, list):
            return [self._move_tensors(v, device, non_blocking) for v in value]
        elif isinstance(value, tuple):
            return tuple(self._move_tensors(v, device, non_blocking) for v in value)
        return value