# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only
"""
Module for computational graph execution.
Classes:
Task: Abstract base class representing a computational task.
Executor: Class for scheduling and executing directed acyclic task graphs.
"""
import os
import sys
import gc
import logging
import networkx
import torch
import tqdm
from pydantic import BaseModel
from typing_extensions import Generic, TypeVar
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from mergekit.common import get_torch_accelerator_module
# Windows/NVIDIA-specific allocator tuning to reduce fragmentation.
# setdefault() keeps any user-provided allocator config intact.
if sys.platform == "win32":
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:32")
ValueT = TypeVar("ValueT")
LOG = logging.getLogger(__name__)
class Task(ABC, BaseModel, Generic[ValueT], frozen=True):
    """Abstract base class for an immutable, hashable unit of work."""

    @abstractmethod
    def arguments(self) -> Dict[str, "Task"]:
        """Tasks whose results this task consumes, keyed by the keyword
        argument name they are passed to execute() under."""
        ...

    @abstractmethod
    def execute(self, **kwargs) -> ValueT:
        """Compute this task's value from its resolved arguments."""
        ...

    def priority(self) -> int:
        """Scheduling priority; higher values run earlier within a group."""
        return 0

    def group_label(self) -> Optional[str]:
        """Label used to schedule related tasks together."""
        return None

    def uses_accelerator(self) -> bool:
        """Whether this task benefits from running on the math device."""
        return False

    def main_thread_only(self) -> bool:
        """Whether this task must execute on the main thread."""
        return False

    def duplicate_per_gpu(self) -> bool:
        """Whether this task should be duplicated for each GPU."""
        return False
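
# Illustrative sketch of the Task contract (these classes are hypothetical
# examples, not part of this module): execute() receives each dependency's
# computed value under the key it was given in arguments().
#
#   class Constant(Task[torch.Tensor]):
#       value: float
#
#       def arguments(self) -> Dict[str, Task]:
#           return {}
#
#       def execute(self) -> torch.Tensor:
#           return torch.tensor(self.value)
#
#   class Double(Task[torch.Tensor]):
#       source: Constant
#
#       def arguments(self) -> Dict[str, Task]:
#           return {"source": self.source}
#
#       def execute(self, source: torch.Tensor) -> torch.Tensor:
#           return source * 2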
class TaskUniverse:
    """Interning registry for tasks and their dependency edges.

    Tasks are deduplicated by object identity (fast path) and by value
    equality, so equal tasks share a single index."""

    tasks: List[Task]
    task_to_index: Dict[Task, int]
    task_arguments: Dict[int, Dict[str, int]]
    _type_id_to_index: Dict[Tuple[type, int], int]
def __init__(self, tasks: Optional[Iterable[Task]] = None):
self.tasks = []
self.task_to_index = {}
self.task_arguments = {}
self._type_id_to_index = {}
if tasks is not None:
for task in tasks:
self.add_task(task)
def add_task(self, task: Task, recursive: bool = True) -> "TaskHandle":
_ti_key = (type(task), id(task))
if _ti_key in self._type_id_to_index:
index = self._type_id_to_index[_ti_key]
return TaskHandle(self, index)
index = self.task_to_index.setdefault(task, len(self.tasks))
if index < len(self.tasks):
return TaskHandle(self, index)
self.tasks.append(task)
self._type_id_to_index[_ti_key] = index
if recursive:
self.task_arguments[index] = {}
for k, v in task.arguments().items():
self.task_arguments[index][k] = self.add_task(v, recursive=True)._index
return TaskHandle(self, index)
def get_handle(self, task: Task) -> Optional["TaskHandle"]:
if task not in self.task_to_index:
return None
return TaskHandle(self, self.task_to_index[task])
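
# Illustrative sketch (using the hypothetical Constant task from above):
# equal task objects are interned to a single handle.
#
#   universe = TaskUniverse()
#   h1 = universe.add_task(Constant(value=1.0))
#   h2 = universe.add_task(Constant(value=1.0))  # deduplicated by equality
#   assert h1 == h2 and len(universe.tasks) == 1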
class TaskHandle:
    """Lightweight, hashable reference to a task within a TaskUniverse."""

    __slots__ = ["_universe", "_index"]
    _universe: TaskUniverse
    _index: int
def __init__(self, universe: TaskUniverse, index: int):
self._universe = universe
self._index = index
def task(self) -> Task:
return self._universe.tasks[self._index]
def arguments(self) -> Dict[str, "TaskHandle"]:
return {
k: TaskHandle(self._universe, v)
for k, v in self._universe.task_arguments[self._index].items()
}
def __eq__(self, other):
if not isinstance(other, TaskHandle):
return False
return self._index == other._index and self._universe is other._universe
def __hash__(self):
return self._index
def __str__(self):
return f"TaskHandle({type(self.task()).__name__}, {self._index})"
__repr__ = __str__
class ExecutionSchedule:
    """Topologically ordered tasks plus the schedule index at which each
    value is last consumed (used for eviction)."""

    tasks: List[TaskHandle]
    last_use_index: Dict[TaskHandle, int]
def __init__(self, tasks: List[TaskHandle], last_use_index: Dict[TaskHandle, int]):
self.tasks = tasks
self.last_use_index = last_use_index
def build_schedule(
    targets: List[TaskHandle], cached_values: Dict[TaskHandle, Any]
) -> ExecutionSchedule:
    """Topologically sort the dependency graph of `targets`, grouping by
    group_label and breaking ties by descending priority, and record where
    each intermediate value is last used so it can be evicted."""
if not targets:
return ExecutionSchedule(tasks=[], last_use_index={})
universe = targets[0]._universe
dummy_handle = TaskHandle(universe, -1)
edge_tups: List[Tuple[TaskHandle, TaskHandle]] = []
explored = set()
to_explore = set(targets)
while to_explore:
task = to_explore.pop()
if task in explored:
continue
explored.add(task)
if task in (cached_values or {}):
continue
for dep in task.arguments().values():
to_explore.add(dep)
edge_tups.append((dep, task))
for target in targets:
edge_tups.append((dummy_handle, target))
def _compare_key(node: TaskHandle) -> Tuple[str, int]:
if node._index < 0:
return ("", 0)
task = node.task()
return (task.group_label() or "", -task.priority())
graph = networkx.DiGraph(edge_tups)
schedule: List[TaskHandle] = [
node
for node in networkx.lexicographical_topological_sort(graph, key=_compare_key)
if (node != dummy_handle) and node not in (cached_values or {})
]
last_use_index = {}
for idx, task in reversed(list(enumerate(schedule))):
for dep in task.arguments().values():
if dep not in last_use_index:
last_use_index[dep] = idx
if task not in last_use_index:
last_use_index[task] = idx
for task in cached_values or {}:
if task not in last_use_index:
last_use_index[task] = len(schedule) + 1
return ExecutionSchedule(tasks=schedule, last_use_index=last_use_index)
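
# Illustrative sketch: for hypothetical handles a, b and c = Sum(a, b),
#
#   schedule = build_schedule([c], cached_values={})
#   # schedule.tasks          -> [a, b, c] (a/b order set by group/priority)
#   # schedule.last_use_index -> {a: 2, b: 2, c: 2}; a and b are evictable
#   # as soon as c (at index 2) has consumed them.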
class Executor:
    """Schedules and executes a task graph, moving values between a math
    (compute) device and a storage device and evicting dead values."""

    math_device: torch.device
    storage_device: torch.device
universe: TaskUniverse
targets: List[TaskHandle]
schedule: ExecutionSchedule
cached_values: Optional[Dict[TaskHandle, Any]]
def __init__(
self,
targets: Union[List[Task], List[TaskHandle]],
math_device: torch.device = torch.device("cpu"),
storage_device: torch.device = torch.device("cpu"),
cached_values: Optional[Dict[TaskHandle, Any]] = None,
):
self.cached_values = cached_values
if isinstance(math_device, str):
math_device = torch.device(math_device)
if isinstance(storage_device, str):
storage_device = torch.device(storage_device)
self.math_device = math_device
self.storage_device = storage_device
if targets and isinstance(targets[0], Task):
universe = TaskUniverse(targets)
targets = [universe.add_task(t) for t in targets]
elif targets and isinstance(targets[0], TaskHandle):
universe = targets[0]._universe
elif not targets:
universe = TaskUniverse()
else:
raise ValueError("Targets must be a list of Task or TaskHandle instances")
self.universe = universe
self.targets = targets
self.schedule = build_schedule(targets, cached_values=cached_values)
def _slice_argument(self, arg: Any, start: int, end: int) -> Any:
"""Helper to slice tensors within nested structures."""
if isinstance(arg, torch.Tensor):
# Only slice if the dimension is large enough
if arg.shape[0] > 1:
return arg[start:end]
return arg
elif isinstance(arg, dict):
return {k: self._slice_argument(v, start, end) for k, v in arg.items()}
elif isinstance(arg, list):
return [self._slice_argument(v, start, end) for v in arg]
elif isinstance(arg, tuple):
return tuple(self._slice_argument(v, start, end) for v in arg)
return arg
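
    # Illustrative sketch of _slice_argument on a nested structure:
    #
    #   args = {"t": [torch.zeros(8, 4), torch.zeros(1, 4)]}
    #   self._slice_argument(args, 0, 4)
    #   # -> {"t": [a (4, 4) slice, the untouched (1, 4) singleton]}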
    def _execute_chunked(
        self, task: Task, arguments: Dict[str, Any], chunk_size: int
    ) -> Any:
        """
        Execute a task by splitting input tensors into row chunks, running
        each chunk on the math device, and concatenating the results on the
        storage device.
        """
# Find a reference tensor to determine batch size
ref_tensor = None
for arg in arguments.values():
if isinstance(arg, torch.Tensor):
ref_tensor = arg
break
elif isinstance(arg, dict):
for v in arg.values():
if isinstance(v, torch.Tensor):
ref_tensor = v
break
            if ref_tensor is not None:
                break
if ref_tensor is None:
raise ValueError("No tensors found to chunk")
total_rows = ref_tensor.shape[0]
results = []
accelerator = get_torch_accelerator_module(self.math_device.type) if self.math_device.type != "cpu" else None
# Process in chunks
for i in range(0, total_rows, chunk_size):
end = min(i + chunk_size, total_rows)
# Slice inputs
chunk_args = {
k: self._slice_argument(v, i, end)
for k, v in arguments.items()
}
# Move chunk inputs to GPU
chunk_args_gpu = {
k: self._move_tensors(v, self.math_device)
for k, v in chunk_args.items()
}
# Execute
chunk_res = task.execute(**chunk_args_gpu)
# Move result to CPU immediately
chunk_res_cpu = self._move_tensors(chunk_res, self.storage_device)
results.append(chunk_res_cpu)
# Cleanup
del chunk_args
del chunk_args_gpu
del chunk_res
            # Empty the accelerator cache every chunk so memory-hungry
            # methods (e.g. Magic) don't accumulate fragmentation
if accelerator:
accelerator.empty_cache()
# Concatenate results
if isinstance(results[0], torch.Tensor):
return torch.cat(results, dim=0)
elif isinstance(results[0], dict):
# Reassemble dict of tensors
out = {}
for k in results[0].keys():
out[k] = torch.cat([r[k] for r in results], dim=0)
return out
else:
raise ValueError("Unsupported return type for chunking")
    def _run(
        self,
        quiet: bool = False,
        desc: Optional[str] = None,
    ) -> Iterator[Tuple[TaskHandle, Any]]:
        """Core execution loop. For each scheduled task, try full execution
        on the math device, then chunked execution with progressively smaller
        chunks, then fall back to CPU; yield (handle, value) per target."""
last_use_index = self.schedule.last_use_index
values: Dict[TaskHandle, Any] = {}
if self.cached_values:
for task, value in self.cached_values.items():
values[task] = value
is_gpu_execution = self.math_device.type != "cpu"
accelerator = get_torch_accelerator_module(self.math_device.type) if is_gpu_execution else None
for idx, task_handle in (
pbar := tqdm.tqdm(
list(enumerate(self.schedule.tasks)),
disable=quiet,
desc=desc or "Executing graph",
)
):
task = task_handle.task()
task_type = type(task).__name__
            # Heuristic: don't force I/O-bound tasks onto the GPU.
            # PermutedEmbeddings is essentially a gather operation: hard to
            # chunk, and better kept on CPU when memory is tight.
            is_io_task = task_type in [
                "LoadTensor",
                "GatherTensors",
                "SaveTensor",
                "TensorWriterTask",
                "FinalizeModel",
                "PermutedEmbeddings",
            ]
want_gpu = is_gpu_execution and (task.uses_accelerator() or not is_io_task)
success = False
if want_gpu:
try:
# 1. Try Full GPU Execution
arguments = {}
for name, dep_handle in task_handle.arguments().items():
value = values[dep_handle]
value = self._move_tensors(value, self.math_device)
arguments[name] = value
res = task.execute(**arguments)
del arguments
res = self._move_tensors(res, self.storage_device)
values[task_handle] = res
success = True
except torch.OutOfMemoryError:
# Cleanup
arguments = None
res = None
gc.collect()
                    if accelerator:
                        accelerator.empty_cache()
# 2. Try Chunked GPU Execution with Adaptive Sizing
chunk_sizes = [4096, 2048, 1024, 512, 256, 128, 64]
# Reload arguments on CPU
arguments = {}
for name, dep_handle in task_handle.arguments().items():
arguments[name] = values[dep_handle] # Already on storage device
for chunk_size in chunk_sizes:
try:
LOG.info(f"OOM on {task_type}. Attempting chunked GPU execution (size={chunk_size})...")
res = self._execute_chunked(task, arguments, chunk_size=chunk_size)
values[task_handle] = res
success = True
LOG.info(f"Chunked execution successful for {task_type} (size={chunk_size})")
break
except Exception as e:
LOG.warning(f"Chunked execution failed at size {chunk_size} ({str(e)}).")
gc.collect()
                            if accelerator:
                                accelerator.empty_cache()
# If it wasn't an OOM (e.g. index error), stop trying chunking
if not isinstance(e, torch.OutOfMemoryError):
break
# 3. CPU Fallback
if not success:
if want_gpu:
LOG.warning(f"All GPU attempts failed for {task_type}. Falling back to CPU.")
# Ensure we clean up any GPU debris before CPU attempt
if is_gpu_execution:
gc.collect()
                    if accelerator:
                        accelerator.empty_cache()
arguments = {}
for name, dep_handle in task_handle.arguments().items():
value = values[dep_handle]
value = self._move_tensors(value, torch.device("cpu"))
arguments[name] = value
res = task.execute(**arguments)
del arguments
res = self._move_tensors(res, self.storage_device)
values[task_handle] = res
del res
if task_handle in self.targets:
yield (task_handle, values[task_handle])
# Evict unreferenced values
expired = []
for key in values:
if idx >= last_use_index[key]:
expired.append(key)
for key in expired:
del values[key]
# Aggressive cleanup
if is_gpu_execution:
gc.collect()
                if accelerator:
                    accelerator.empty_cache()
del values
del pbar
def run(
self,
quiet: bool = False,
desc: Optional[str] = None,
    ) -> Iterator[Tuple[Task, Any]]:
        """Execute the scheduled graph, yielding (task, value) for each
        target task as it completes."""
for handle, value in self._run(quiet=quiet, desc=desc):
yield (handle.task(), value)
    def execute(self, desc: Optional[str] = None) -> None:
        """Run the graph to completion, discarding the target values."""
for _ in self.run(desc=desc):
pass
def _move_tensors(
self, value: Any, device: torch.device, non_blocking: Optional[bool] = None
    ) -> Any:
        """Recursively move tensors in `value` to `device`, preserving
        container structure; CUDA/XPU transfers default to non-blocking."""
        if non_blocking is None:
            non_blocking = device.type in ["cuda", "xpu"]
if isinstance(value, torch.Tensor):
if value.device == device:
return value
return value.to(device=device, non_blocking=non_blocking)
elif isinstance(value, dict):
return {k: self._move_tensors(v, device, non_blocking) for k, v in value.items()}
elif isinstance(value, list):
return [self._move_tensors(v, device, non_blocking) for v in value]
elif isinstance(value, tuple):
return tuple(self._move_tensors(v, device, non_blocking) for v in value)
        return value
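

# Illustrative sketch of end-to-end usage, reusing the hypothetical Constant
# and Double tasks from above (not part of this module):
#
#   executor = Executor(
#       targets=[Double(source=Constant(value=2.0))],
#       math_device=torch.device("cpu"),
#       storage_device=torch.device("cpu"),
#   )
#   for task, value in executor.run(desc="Toy graph"):
#       print(type(task).__name__, value)  # -> "Double tensor(4.)"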