Naphula commited on
Commit
630c8c7
·
verified ·
1 Parent(s): ec5ad0d

Upload 3 files

Browse files
graph_v18_runpod_3090.py ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # graph_v18.py - Optimized for 3090 runpod
2
+ # Copyright (C) 2025 Arcee AI
3
+ # SPDX-License-Identifier: LGPL-3.0-only
4
+ """
5
+ Module for computational graph execution.
6
+
7
+ Classes:
8
+ Task: Abstract base class representing a computational task.
9
+ Executor: Class for scheduling and executing directed acyclic task graphs.
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ import gc
15
+ import logging
16
+ import networkx
17
+ import torch
18
+ import tqdm
19
+ from pydantic import BaseModel
20
+ from typing_extensions import Generic, TypeVar
21
+ from abc import ABC, abstractmethod
22
+ from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
23
+
24
+ from mergekit.common import get_torch_accelerator_module
25
+
26
+ # ============================================================================
27
+ # CONFIGURATION SECTION - TUNE THESE PARAMETERS FOR YOUR GPU
28
+ # ============================================================================
29
+
30
+ # --- PRIMARY VRAM TARGETS ---
31
+ # For 3060 TI (8GB): Start with 7.2-7.4GB. Increase if stable, decrease if OOM.
32
+ # For 3060 (12GB): Try 10.5-11.0GB
33
+ # For 4GB cards: Try 3.2-3.5GB
34
+ TARGET_VRAM_GB = 21.8 # Target VRAM usage in GB (TUNE THIS FIRST)
35
+
36
+ # Safety margin to account for PyTorch overhead and fragmentation
37
+ # Windows typically needs ~0.8GB, Linux ~0.5GB
38
+ VRAM_SAFETY_MARGIN_GB = 1.2 # Reduce to 0.5-0.6 on Linux, increase to 1.0 if unstable
39
+
40
+ # --- CUDA MEMORY ALLOCATOR CONFIGURATION ---
41
+ # Smaller values = less fragmentation but more overhead
42
+ # 24MB is optimal for 8GB cards, 32MB for 12GB+ cards
43
+ CUDA_MAX_SPLIT_SIZE_MB = 24 # Options: 16, 24, 32, 64
44
+
45
+ # --- CHUNK SIZE BEHAVIOR ---
46
+ # How aggressively to reduce chunk size on OOM (0.5-0.9 range)
47
+ # Lower = more conservative (slower but safer), Higher = more aggressive
48
+ CHUNK_REDUCTION_FACTOR = 0.75 # Options: 0.5 (safe), 0.7 (balanced), 0.85 (aggressive)
49
+
50
+ # Minimum chunk size before giving up and falling back to CPU
51
+ MIN_CHUNK_SIZE = 1 # Usually keep at 1, increase to 4-8 if seeing micro-chunk overhead
52
+
53
+ # Enable power-of-2 alignment for chunk sizes (following measure.py strategy)
54
+ # This improves memory allocation efficiency
55
+ ENABLE_POWER_OF_2_ALIGNMENT = True # Set False if causing issues
56
+
57
+ # --- TASK-SPECIFIC MEMORY MULTIPLIERS ---
58
+ # These control how much extra VRAM to reserve for specific task types
59
+ # Increase if task OOMs, decrease if underutilizing VRAM
60
+ TASK_MULTIPLIERS = {
61
+ "ModelStock": 2.0,
62
+ "Karcher": 3.0,
63
+ "Consensus": 3.0,
64
+ "Prometheus": 6.0, # Forces the 3090 to start with ~8k chunks instead of 65k.
65
+ "default": 1.2,
66
+ }
67
+
68
+ # --- MEMORY CLEANUP BEHAVIOR ---
69
+ # Enable aggressive garbage collection and cache clearing
70
+ # True = slower but more stable, False = faster but may fragment memory
71
+ ENABLE_AGGRESSIVE_CLEANUP = True # Set False if merges are very stable
72
+
73
+ # How often to force cleanup (every N tasks). 0 = after every task
74
+ CLEANUP_FREQUENCY = 2 # Options: 0 (always), 1, 2, 5, 10
75
+
76
+ # --- FALLBACK STRATEGY ---
77
+ # Fixed chunk sizes to try if adaptive chunking fails
78
+ # Powers of 2 work best for GPU memory alignment
79
+ FALLBACK_CHUNK_SIZES = [16384, 8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2]
80
+
81
+ # --- FAST PATH OPTIMIZATION ---
82
+ # Try to execute entire task at once before chunking
83
+ # True = faster when it works, False = always chunk (more conservative)
84
+ ENABLE_FAST_PATH = True # Set False if getting frequent OOM on large tasks
85
+
86
+ # --- TASK ROUTING ---
87
+ # Tasks that should always run on CPU (typically I/O bound)
88
+ CPU_ONLY_TASKS = [
89
+ "LoadTensor",
90
+ "GatherTensors",
91
+ "SaveTensor",
92
+ "TensorWriterTask",
93
+ "FinalizeModel",
94
+ "PermutedEmbeddings", # Gather operations don't benefit from GPU
95
+ ]
96
+
97
+ # ============================================================================
98
+ # END OF CONFIGURATION SECTION
99
+ # ============================================================================
100
+
101
+ if sys.platform == "win32":
102
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = f"max_split_size_mb:{CUDA_MAX_SPLIT_SIZE_MB}"
103
+
104
+ ValueT = TypeVar("ValueT")
105
+ LOG = logging.getLogger(__name__)
106
+
107
+
108
+ def _round_to_power_of_2(n: int, prefer_lower: bool = True) -> int:
109
+ """Round to nearest power of 2 for memory alignment."""
110
+ if n <= 0:
111
+ return 1
112
+ if n == 1:
113
+ return 1
114
+
115
+ # Find the two nearest powers of 2
116
+ power = n.bit_length() - 1
117
+ lower = 1 << power
118
+ upper = 1 << (power + 1)
119
+
120
+ if prefer_lower or (n - lower) < (upper - n):
121
+ return lower
122
+ return upper
123
+
124
+
125
+ class Task(ABC, BaseModel, Generic[ValueT], frozen=True):
126
+ @abstractmethod
127
+ def arguments(self) -> Dict[str, "Task"]:
128
+ ...
129
+
130
+ @abstractmethod
131
+ def execute(self, **kwargs) -> ValueT:
132
+ ...
133
+
134
+ def priority(self) -> int:
135
+ return 0
136
+
137
+ def group_label(self) -> Optional[str]:
138
+ return None
139
+
140
+ def uses_accelerator(self) -> bool:
141
+ return False
142
+
143
+ def main_thread_only(self) -> bool:
144
+ return False
145
+
146
+ def duplicate_per_gpu(self) -> bool:
147
+ return False
148
+
149
+
150
+ class TaskUniverse:
151
+ tasks: List[Task]
152
+ task_to_index: Dict[Task, int]
153
+ task_arguments: Dict[int, Dict[str, int]]
154
+ _type_id_to_index: Dict[Tuple[type, int], int]
155
+
156
+ def __init__(self, tasks: Optional[Iterable[Task]] = None):
157
+ self.tasks = []
158
+ self.task_to_index = {}
159
+ self.task_arguments = {}
160
+ self._type_id_to_index = {}
161
+ if tasks is not None:
162
+ for task in tasks:
163
+ self.add_task(task)
164
+
165
+ def add_task(self, task: Task, recursive: bool = True) -> "TaskHandle":
166
+ _ti_key = (type(task), id(task))
167
+ if _ti_key in self._type_id_to_index:
168
+ index = self._type_id_to_index[_ti_key]
169
+ return TaskHandle(self, index)
170
+
171
+ index = self.task_to_index.setdefault(task, len(self.tasks))
172
+ if index < len(self.tasks):
173
+ return TaskHandle(self, index)
174
+ self.tasks.append(task)
175
+ self._type_id_to_index[_ti_key] = index
176
+
177
+ if recursive:
178
+ self.task_arguments[index] = {}
179
+ for k, v in task.arguments().items():
180
+ self.task_arguments[index][k] = self.add_task(v, recursive=True)._index
181
+ return TaskHandle(self, index)
182
+
183
+ def get_handle(self, task: Task) -> Optional["TaskHandle"]:
184
+ if task not in self.task_to_index:
185
+ return None
186
+ return TaskHandle(self, self.task_to_index[task])
187
+
188
+
189
+ class TaskHandle:
190
+ __slots__ = ["_universe", "_index"]
191
+ _universe: TaskUniverse
192
+ _index: int
193
+
194
+ def __init__(self, universe: TaskUniverse, index: int):
195
+ self._universe = universe
196
+ self._index = index
197
+
198
+ def task(self) -> Task:
199
+ return self._universe.tasks[self._index]
200
+
201
+ def arguments(self) -> Dict[str, "TaskHandle"]:
202
+ return {
203
+ k: TaskHandle(self._universe, v)
204
+ for k, v in self._universe.task_arguments[self._index].items()
205
+ }
206
+
207
+ def __eq__(self, other):
208
+ if not isinstance(other, TaskHandle):
209
+ return False
210
+ return self._index == other._index and self._universe is other._universe
211
+
212
+ def __hash__(self):
213
+ return self._index
214
+
215
+ def __str__(self):
216
+ return f"TaskHandle({type(self.task()).__name__}, {self._index})"
217
+
218
+ __repr__ = __str__
219
+
220
+
221
+ class ExecutionSchedule:
222
+ tasks: List[TaskHandle]
223
+ last_use_index: Dict[TaskHandle, int]
224
+
225
+ def __init__(self, tasks: List[TaskHandle], last_use_index: Dict[TaskHandle, int]):
226
+ self.tasks = tasks
227
+ self.last_use_index = last_use_index
228
+
229
+
230
+ def build_schedule(
231
+ targets: List[TaskHandle], cached_values: Dict[TaskHandle, Any]
232
+ ) -> ExecutionSchedule:
233
+ if not targets:
234
+ return ExecutionSchedule(tasks=[], last_use_index={})
235
+
236
+ universe = targets[0]._universe
237
+ dummy_handle = TaskHandle(universe, -1)
238
+ edge_tups: List[Tuple[TaskHandle, TaskHandle]] = []
239
+
240
+ explored = set()
241
+ to_explore = set(targets)
242
+ while to_explore:
243
+ task = to_explore.pop()
244
+ if task in explored:
245
+ continue
246
+ explored.add(task)
247
+ if task in (cached_values or {}):
248
+ continue
249
+ for dep in task.arguments().values():
250
+ to_explore.add(dep)
251
+ edge_tups.append((dep, task))
252
+
253
+ for target in targets:
254
+ edge_tups.append((dummy_handle, target))
255
+
256
+ def _compare_key(node: TaskHandle) -> Tuple[str, int]:
257
+ if node._index < 0:
258
+ return ("", 0)
259
+ task = node.task()
260
+ return (task.group_label() or "", -task.priority())
261
+
262
+ graph = networkx.DiGraph(edge_tups)
263
+ schedule: List[TaskHandle] = [
264
+ node
265
+ for node in networkx.lexicographical_topological_sort(graph, key=_compare_key)
266
+ if (node != dummy_handle) and node not in (cached_values or {})
267
+ ]
268
+
269
+ last_use_index = {}
270
+ for idx, task in reversed(list(enumerate(schedule))):
271
+ for dep in task.arguments().values():
272
+ if dep not in last_use_index:
273
+ last_use_index[dep] = idx
274
+ if task not in last_use_index:
275
+ last_use_index[task] = idx
276
+ for task in cached_values or {}:
277
+ if task not in last_use_index:
278
+ last_use_index[task] = len(schedule) + 1
279
+
280
+ return ExecutionSchedule(tasks=schedule, last_use_index=last_use_index)
281
+
282
+
283
+ class Executor:
284
+ math_device: torch.device
285
+ storage_device: torch.device
286
+ universe: TaskUniverse
287
+ targets: List[TaskHandle]
288
+ schedule: ExecutionSchedule
289
+ cached_values: Optional[Dict[TaskHandle, Any]]
290
+ _task_counter: int
291
+
292
+ def __init__(
293
+ self,
294
+ targets: Union[List[Task], List[TaskHandle]],
295
+ math_device: torch.device = torch.device("cpu"),
296
+ storage_device: torch.device = torch.device("cpu"),
297
+ cached_values: Optional[Dict[TaskHandle, Any]] = None,
298
+ ):
299
+ self.cached_values = cached_values
300
+ self._task_counter = 0
301
+
302
+ if isinstance(math_device, str):
303
+ math_device = torch.device(math_device)
304
+ if isinstance(storage_device, str):
305
+ storage_device = torch.device(storage_device)
306
+ self.math_device = math_device
307
+ self.storage_device = storage_device
308
+
309
+ if targets and isinstance(targets[0], Task):
310
+ universe = TaskUniverse(targets)
311
+ targets = [universe.add_task(t) for t in targets]
312
+ elif targets and isinstance(targets[0], TaskHandle):
313
+ universe = targets[0]._universe
314
+ elif not targets:
315
+ universe = TaskUniverse()
316
+ else:
317
+ raise ValueError("Targets must be a list of Task or TaskHandle instances")
318
+
319
+ self.universe = universe
320
+ self.targets = targets
321
+ self.schedule = build_schedule(targets, cached_values=cached_values)
322
+
323
+ def _slice_argument(self, arg: Any, start: int, end: int) -> Any:
324
+ """Recursively slice tensors within nested structures."""
325
+ if isinstance(arg, torch.Tensor):
326
+ if arg.shape[0] > 1:
327
+ return arg[start:end]
328
+ return arg
329
+ elif isinstance(arg, dict):
330
+ return {k: self._slice_argument(v, start, end) for k, v in arg.items()}
331
+ elif isinstance(arg, list):
332
+ return [self._slice_argument(v, start, end) for v in arg]
333
+ elif isinstance(arg, tuple):
334
+ return tuple(self._slice_argument(v, start, end) for v in arg)
335
+ return arg
336
+
337
+ def _get_memory_stats(self) -> Dict[str, float]:
338
+ """Get current VRAM statistics in GB."""
339
+ if self.math_device.type != "cuda":
340
+ return {}
341
+
342
+ allocated = torch.cuda.memory_allocated(self.math_device) / (1024**3)
343
+ reserved = torch.cuda.memory_reserved(self.math_device) / (1024**3)
344
+ total = torch.cuda.get_device_properties(self.math_device).total_memory / (1024**3)
345
+
346
+ return {
347
+ "allocated_gb": allocated,
348
+ "reserved_gb": reserved,
349
+ "total_gb": total,
350
+ "free_gb": total - allocated,
351
+ }
352
+
353
+ def _get_adaptive_chunk_size(self, task: Task, arguments: Dict[str, Any]) -> int:
354
+ """
355
+ Calculate optimal chunk size based on available VRAM and task requirements.
356
+
357
+ This implements the "measure.py strategy" of targeting a specific VRAM fill level
358
+ rather than using currently available memory, which prevents oscillation.
359
+ """
360
+ if self.math_device.type == "cpu":
361
+ return 1024 # Large default for CPU
362
+
363
+ # Get hardware capacity
364
+ total_vram = torch.cuda.get_device_properties(self.math_device).total_memory
365
+ target_bytes = TARGET_VRAM_GB * (1024**3)
366
+
367
+ # Analyze tensor dimensions and count
368
+ num_tensors = 0
369
+ width = 0
370
+ bytes_per_element = 4 # Default float32
371
+
372
+ for arg in arguments.values():
373
+ if isinstance(arg, torch.Tensor):
374
+ num_tensors += 1
375
+ width = max(width, arg.shape[-1] if len(arg.shape) > 1 else arg.shape[0])
376
+ bytes_per_element = arg.element_size()
377
+ elif isinstance(arg, dict):
378
+ for v in arg.values():
379
+ if isinstance(v, torch.Tensor):
380
+ num_tensors += 1
381
+ width = max(width, v.shape[-1] if len(v.shape) > 1 else v.shape[0])
382
+ bytes_per_element = v.element_size()
383
+
384
+ if num_tensors == 0 or width == 0:
385
+ return 512 # Safe default
386
+
387
+ # Get task-specific multiplier
388
+ task_name = type(task).__name__
389
+ multiplier = TASK_MULTIPLIERS.get("default", 1.2)
390
+
391
+ for key, mult in TASK_MULTIPLIERS.items():
392
+ if key in task_name:
393
+ multiplier = mult
394
+ break
395
+
396
+ # Calculate bytes per row with multiplier for working memory
397
+ bytes_per_row = num_tensors * width * bytes_per_element * multiplier
398
+
399
+ # Calculate usable VRAM (target minus current allocation and safety margin)
400
+ current_allocated = torch.cuda.memory_allocated(self.math_device)
401
+ safety_bytes = VRAM_SAFETY_MARGIN_GB * (1024**3)
402
+ usable_vram = max(target_bytes - current_allocated - safety_bytes, 1024 * (1024**2))
403
+
404
+ # Calculate chunk size
405
+ chunk_size = max(MIN_CHUNK_SIZE, int(usable_vram // bytes_per_row))
406
+
407
+ # Apply power-of-2 alignment if enabled (measure.py strategy)
408
+ if ENABLE_POWER_OF_2_ALIGNMENT and chunk_size > MIN_CHUNK_SIZE:
409
+ chunk_size = _round_to_power_of_2(chunk_size, prefer_lower=True)
410
+
411
+ LOG.debug(f"Calculated chunk size: {chunk_size} (tensors={num_tensors}, width={width}, mult={multiplier:.2f})")
412
+ return chunk_size
413
+
414
+ def _execute_chunked(self, task: Task, arguments: Dict[str, Any]) -> Any:
415
+ """
416
+ Execute task in chunks with progressive fallback strategy.
417
+
418
+ Strategy:
419
+ 1. Try adaptive chunk size
420
+ 2. On OOM, reduce by CHUNK_REDUCTION_FACTOR
421
+ 3. Continue until success or MIN_CHUNK_SIZE reached
422
+ """
423
+ # Find total rows to process
424
+ total_rows = 0
425
+ for arg in arguments.values():
426
+ if isinstance(arg, torch.Tensor):
427
+ total_rows = arg.shape[0]
428
+ break
429
+ elif isinstance(arg, dict):
430
+ for v in arg.values():
431
+ if isinstance(v, torch.Tensor):
432
+ total_rows = v.shape[0]
433
+ break
434
+ if total_rows > 0:
435
+ break
436
+
437
+ if total_rows == 0:
438
+ return task.execute(**arguments)
439
+
440
+ # Calculate initial chunk size
441
+ chunk_size = self._get_adaptive_chunk_size(task, arguments)
442
+
443
+ # FAST PATH: Try to execute all at once if chunk size >= total rows
444
+ if ENABLE_FAST_PATH and chunk_size >= total_rows:
445
+ try:
446
+ gpu_args = {
447
+ k: self._move_tensors(v, self.math_device)
448
+ for k, v in arguments.items()
449
+ }
450
+ res = task.execute(**gpu_args)
451
+ result = self._move_tensors(res, self.storage_device)
452
+ del gpu_args, res
453
+ if ENABLE_AGGRESSIVE_CLEANUP:
454
+ torch.cuda.empty_cache()
455
+ return result
456
+ except torch.OutOfMemoryError:
457
+ LOG.warning(f"Fast path OOM, falling back to chunking")
458
+ torch.cuda.empty_cache()
459
+ gc.collect()
460
+ chunk_size = max(MIN_CHUNK_SIZE, total_rows // 2)
461
+
462
+ # Chunked execution with progressive reduction
463
+ results = []
464
+ i = 0
465
+ oom_count = 0
466
+
467
+ while i < total_rows:
468
+ end = min(i + chunk_size, total_rows)
469
+
470
+ try:
471
+ chunk_args_gpu = {
472
+ k: self._move_tensors(self._slice_argument(v, i, end), self.math_device)
473
+ for k, v in arguments.items()
474
+ }
475
+ chunk_res = task.execute(**chunk_args_gpu)
476
+ results.append(self._move_tensors(chunk_res, self.storage_device))
477
+
478
+ del chunk_args_gpu, chunk_res
479
+
480
+ # Aggressive cleanup per measure.py strategy
481
+ if ENABLE_AGGRESSIVE_CLEANUP:
482
+ torch.cuda.empty_cache()
483
+
484
+ i = end # Move to next chunk
485
+ oom_count = 0 # Reset OOM counter on success
486
+
487
+ except torch.OutOfMemoryError:
488
+ oom_count += 1
489
+ torch.cuda.empty_cache()
490
+ gc.collect()
491
+
492
+ # Progressive reduction
493
+ old_chunk = chunk_size
494
+ chunk_size = max(MIN_CHUNK_SIZE, int(chunk_size * CHUNK_REDUCTION_FACTOR))
495
+
496
+ # Apply power-of-2 alignment
497
+ if ENABLE_POWER_OF_2_ALIGNMENT:
498
+ chunk_size = _round_to_power_of_2(chunk_size, prefer_lower=True)
499
+
500
+ if chunk_size < MIN_CHUNK_SIZE:
501
+ LOG.error(f"Chunk size below minimum ({MIN_CHUNK_SIZE}), cannot continue")
502
+ raise
503
+
504
+ LOG.warning(
505
+ f"OOM at chunk {old_chunk}, reducing to {chunk_size} "
506
+ f"(attempt {oom_count}, progress: {i}/{total_rows})"
507
+ )
508
+
509
+ # Safety: if we OOM too many times, something is wrong
510
+ if oom_count > 10:
511
+ LOG.error("Too many OOM errors, giving up")
512
+ raise
513
+
514
+ # Concatenate results
515
+ if not results:
516
+ return None
517
+
518
+ if isinstance(results[0], torch.Tensor):
519
+ return torch.cat(results, dim=0)
520
+ elif isinstance(results[0], dict):
521
+ out = {}
522
+ for k in results[0].keys():
523
+ out[k] = torch.cat([r[k] for r in results], dim=0)
524
+ return out
525
+
526
+ return results
527
+
528
+ def _execute_with_fallback(self, task: Task, arguments: Dict[str, Any], accelerator) -> Any:
529
+ """
530
+ Execute task with comprehensive fallback strategy.
531
+
532
+ Strategy:
533
+ 1. Try full GPU execution
534
+ 2. Try adaptive chunking
535
+ 3. Try fixed chunk sizes
536
+ 4. Fall back to CPU
537
+ """
538
+ task_name = type(task).__name__
539
+
540
+ # Strategy 1: Try full GPU execution for light tasks
541
+ try:
542
+ gpu_args = {
543
+ k: self._move_tensors(v, self.math_device)
544
+ for k, v in arguments.items()
545
+ }
546
+ res = task.execute(**gpu_args)
547
+ result = self._move_tensors(res, self.storage_device)
548
+ del gpu_args, res
549
+ return result
550
+ except torch.OutOfMemoryError:
551
+ LOG.debug(f"Full GPU execution failed for {task_name}, trying chunked")
552
+ torch.cuda.empty_cache()
553
+ gc.collect()
554
+ except Exception as e:
555
+ LOG.warning(f"GPU execution error for {task_name}: {e}")
556
+ torch.cuda.empty_cache()
557
+ raise
558
+
559
+ # Strategy 2: Try adaptive chunking
560
+ try:
561
+ result = self._execute_chunked(task, arguments)
562
+ return result
563
+ except torch.OutOfMemoryError:
564
+ LOG.warning(f"Adaptive chunking failed for {task_name}, trying fixed sizes")
565
+ torch.cuda.empty_cache()
566
+ gc.collect()
567
+ except Exception as e:
568
+ LOG.warning(f"Chunking error for {task_name}: {e}")
569
+ raise
570
+
571
+ # Strategy 3: Try fixed chunk sizes
572
+ for chunk_size in FALLBACK_CHUNK_SIZES:
573
+ if chunk_size < MIN_CHUNK_SIZE:
574
+ continue
575
+
576
+ try:
577
+ LOG.info(f"Trying fixed chunk size {chunk_size} for {task_name}")
578
+
579
+ # Get total rows
580
+ total_rows = 0
581
+ for arg in arguments.values():
582
+ if isinstance(arg, torch.Tensor):
583
+ total_rows = arg.shape[0]
584
+ break
585
+ elif isinstance(arg, dict):
586
+ for v in arg.values():
587
+ if isinstance(v, torch.Tensor):
588
+ total_rows = v.shape[0]
589
+ break
590
+ if total_rows > 0:
591
+ break
592
+
593
+ if total_rows == 0:
594
+ break
595
+
596
+ results = []
597
+ for i in range(0, total_rows, chunk_size):
598
+ end = min(i + chunk_size, total_rows)
599
+ chunk_args = {
600
+ k: self._slice_argument(v, i, end)
601
+ for k, v in arguments.items()
602
+ }
603
+ chunk_args_gpu = {
604
+ k: self._move_tensors(v, self.math_device)
605
+ for k, v in chunk_args.items()
606
+ }
607
+ chunk_res = task.execute(**chunk_args_gpu)
608
+ results.append(self._move_tensors(chunk_res, self.storage_device))
609
+ del chunk_args, chunk_args_gpu, chunk_res
610
+
611
+ if ENABLE_AGGRESSIVE_CLEANUP:
612
+ torch.cuda.empty_cache()
613
+
614
+ if isinstance(results[0], torch.Tensor):
615
+ return torch.cat(results, dim=0)
616
+ elif isinstance(results[0], dict):
617
+ out = {}
618
+ for k in results[0].keys():
619
+ out[k] = torch.cat([r[k] for r in results], dim=0)
620
+ return out
621
+ return results
622
+
623
+ except torch.OutOfMemoryError:
624
+ torch.cuda.empty_cache()
625
+ gc.collect()
626
+ continue
627
+ except Exception as e:
628
+ LOG.warning(f"Fixed chunk {chunk_size} failed: {e}")
629
+ break
630
+
631
+ # Strategy 4: CPU fallback
632
+ LOG.warning(f"All GPU strategies failed for {task_name}, using CPU")
633
+ raise torch.OutOfMemoryError("Forcing CPU fallback")
634
+
635
+ def _run(
636
+ self,
637
+ quiet: bool = False,
638
+ desc: Optional[str] = None,
639
+ ) -> Iterator[Tuple[TaskHandle, Any]]:
640
+ last_use_index = self.schedule.last_use_index
641
+
642
+ values: Dict[TaskHandle, Any] = {}
643
+ if self.cached_values:
644
+ for task, value in self.cached_values.items():
645
+ values[task] = value
646
+
647
+ is_gpu_execution = self.math_device.type != "cpu"
648
+ accelerator = get_torch_accelerator_module(self.math_device.type) if is_gpu_execution else None
649
+
650
+ for idx, task_handle in (
651
+ pbar := tqdm.tqdm(
652
+ list(enumerate(self.schedule.tasks)),
653
+ disable=quiet,
654
+ desc=desc or "Executing graph",
655
+ )
656
+ ):
657
+ task = task_handle.task()
658
+ task_type = type(task).__name__
659
+
660
+ # Log memory stats periodically
661
+ if is_gpu_execution and idx % 10 == 0:
662
+ stats = self._get_memory_stats()
663
+ LOG.debug(
664
+ f"Memory: {stats.get('allocated_gb', 0):.2f}GB allocated, "
665
+ f"{stats.get('free_gb', 0):.2f}GB free of {stats.get('total_gb', 0):.2f}GB"
666
+ )
667
+
668
+ # Determine execution strategy
669
+ is_cpu_only_task = task_type in CPU_ONLY_TASKS
670
+ want_gpu = is_gpu_execution and task.uses_accelerator() and not is_cpu_only_task
671
+
672
+ # Collect arguments
673
+ arguments = {k: values[h] for k, h in task_handle.arguments().items()}
674
+
675
+ success = False
676
+
677
+ # Try GPU execution
678
+ if want_gpu:
679
+ try:
680
+ res = self._execute_with_fallback(task, arguments, accelerator)
681
+ values[task_handle] = res
682
+ success = True
683
+ except torch.OutOfMemoryError:
684
+ LOG.warning(f"All GPU strategies exhausted for {task_type}, falling back to CPU")
685
+ success = False
686
+ except Exception as e:
687
+ LOG.error(f"GPU execution failed for {task_type}: {e}")
688
+ success = False
689
+
690
+ # Cleanup after GPU attempt
691
+ if is_gpu_execution and ENABLE_AGGRESSIVE_CLEANUP:
692
+ gc.collect()
693
+ if accelerator:
694
+ accelerator.empty_cache()
695
+
696
+ # CPU fallback
697
+ if not success:
698
+ if want_gpu:
699
+ LOG.info(f"Executing {task_type} on CPU")
700
+
701
+ # Ensure cleanup before CPU execution
702
+ if is_gpu_execution:
703
+ gc.collect()
704
+ if accelerator:
705
+ accelerator.empty_cache()
706
+
707
+ # Move arguments to CPU
708
+ cpu_arguments = {
709
+ k: self._move_tensors(v, torch.device("cpu"))
710
+ for k, v in arguments.items()
711
+ }
712
+
713
+ res = task.execute(**cpu_arguments)
714
+ del cpu_arguments
715
+ res = self._move_tensors(res, self.storage_device)
716
+ values[task_handle] = res
717
+
718
+ del res
719
+ del arguments
720
+
721
+ if task_handle in self.targets:
722
+ yield (task_handle, values[task_handle])
723
+
724
+ # Evict unreferenced values
725
+ expired = []
726
+ for key in values:
727
+ if idx >= last_use_index[key]:
728
+ expired.append(key)
729
+ for key in expired:
730
+ del values[key]
731
+
732
+ # Periodic cleanup (measure.py strategy)
733
+ self._task_counter += 1
734
+ if is_gpu_execution and ENABLE_AGGRESSIVE_CLEANUP:
735
+ if CLEANUP_FREQUENCY == 0 or self._task_counter % max(1, CLEANUP_FREQUENCY) == 0:
736
+ gc.collect()
737
+ if accelerator:
738
+ accelerator.empty_cache()
739
+
740
+ del values
741
+ del pbar
742
+
743
+ def run(
744
+ self,
745
+ quiet: bool = False,
746
+ desc: Optional[str] = None,
747
+ ) -> Iterator[Tuple[Task, Any]]:
748
+ for handle, value in self._run(quiet=quiet, desc=desc):
749
+ yield (handle.task(), value)
750
+
751
+ def execute(self, desc: Optional[str] = None) -> None:
752
+ for _ in self.run(desc=desc):
753
+ pass
754
+
755
+ def _move_tensors(
756
+ self, value: Any, device: torch.device, non_blocking: Optional[bool] = None
757
+ ) -> Any:
758
+ """Move tensors to specified device, handling nested structures."""
759
+ if non_blocking is None:
760
+ non_blocking = device.type in ["cuda", "xpu"]
761
+
762
+ if isinstance(value, torch.Tensor):
763
+ if value.device == device:
764
+ return value
765
+ return value.to(device=device, non_blocking=non_blocking)
766
+ elif isinstance(value, dict):
767
+ return {
768
+ k: self._move_tensors(v, device, non_blocking)
769
+ for k, v in value.items()
770
+ }
771
+ elif isinstance(value, list):
772
+ return [self._move_tensors(v, device, non_blocking) for v in value]
773
+ elif isinstance(value, tuple):
774
+ return tuple(self._move_tensors(v, device, non_blocking) for v in value)
775
+
776
+ return value
graph_v18_runpod_A100.py ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # graph_v18.py - Optimized for A100 runpod
2
+ # Copyright (C) 2025 Arcee AI
3
+ # SPDX-License-Identifier: LGPL-3.0-only
4
+ """
5
+ Module for computational graph execution.
6
+
7
+ Classes:
8
+ Task: Abstract base class representing a computational task.
9
+ Executor: Class for scheduling and executing directed acyclic task graphs.
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ import gc
15
+ import logging
16
+ import networkx
17
+ import torch
18
+ import tqdm
19
+ from pydantic import BaseModel
20
+ from typing_extensions import Generic, TypeVar
21
+ from abc import ABC, abstractmethod
22
+ from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
23
+
24
+ from mergekit.common import get_torch_accelerator_module
25
+
26
+ # ============================================================================
27
+ # CONFIGURATION SECTION - TUNE THESE PARAMETERS FOR YOUR GPU
28
+ # ============================================================================
29
+
30
+ # --- PRIMARY VRAM TARGETS ---
31
+ # For 3060 TI (8GB): Start with 7.2-7.4GB. Increase if stable, decrease if OOM.
32
+ # For 3060 (12GB): Try 10.5-11.0GB
33
+ # For 4GB cards: Try 3.2-3.5GB
34
+ TARGET_VRAM_GB = 74 # Target VRAM usage in GB (TUNE THIS FIRST)
35
+
36
+ # Safety margin to account for PyTorch overhead and fragmentation
37
+ # Windows typically needs ~0.8GB, Linux ~0.5GB
38
+ VRAM_SAFETY_MARGIN_GB = 1.0 # Reduce to 0.5-0.6 on Linux, increase to 1.0 if unstable
39
+
40
+ # --- CUDA MEMORY ALLOCATOR CONFIGURATION ---
41
+ # Smaller values = less fragmentation but more overhead
42
+ # 24MB is optimal for 8GB cards, 32MB for 12GB+ cards
43
+ CUDA_MAX_SPLIT_SIZE_MB = 24 # Options: 16, 24, 32, 64
44
+
45
+ # --- CHUNK SIZE BEHAVIOR ---
46
+ # How aggressively to reduce chunk size on OOM (0.5-0.9 range)
47
+ # Lower = more conservative (slower but safer), Higher = more aggressive
48
+ CHUNK_REDUCTION_FACTOR = 0.75 # Options: 0.5 (safe), 0.7 (balanced), 0.85 (aggressive)
49
+
50
+ # Minimum chunk size before giving up and falling back to CPU
51
+ MIN_CHUNK_SIZE = 1 # Usually keep at 1, increase to 4-8 if seeing micro-chunk overhead
52
+
53
+ # Enable power-of-2 alignment for chunk sizes (following measure.py strategy)
54
+ # This improves memory allocation efficiency
55
+ ENABLE_POWER_OF_2_ALIGNMENT = True # Set False if causing issues
56
+
57
+ # --- TASK-SPECIFIC MEMORY MULTIPLIERS ---
58
+ # These control how much extra VRAM to reserve for specific task types
59
+ # Increase if task OOMs, decrease if underutilizing VRAM
60
+ TASK_MULTIPLIERS = {
61
+ "ModelStock": 2.0,
62
+ "Karcher": 3.0,
63
+ "Consensus": 3.0,
64
+ "Prometheus": 6.0, # Forces the 3090 to start with ~8k chunks instead of 65k.
65
+ "default": 1.2,
66
+ }
67
+
68
+ # --- MEMORY CLEANUP BEHAVIOR ---
69
+ # Enable aggressive garbage collection and cache clearing
70
+ # True = slower but more stable, False = faster but may fragment memory
71
+ ENABLE_AGGRESSIVE_CLEANUP = True # Set False if merges are very stable
72
+
73
+ # How often to force cleanup (every N tasks). 0 = after every task
74
+ CLEANUP_FREQUENCY = 2 # Options: 0 (always), 1, 2, 5, 10
75
+
76
+ # --- FALLBACK STRATEGY ---
77
+ # Fixed chunk sizes to try if adaptive chunking fails
78
+ # Powers of 2 work best for GPU memory alignment
79
+ FALLBACK_CHUNK_SIZES = [65536, 32768, 16384, 8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2]
80
+
81
+ # --- FAST PATH OPTIMIZATION ---
82
+ # Try to execute entire task at once before chunking
83
+ # True = faster when it works, False = always chunk (more conservative)
84
+ ENABLE_FAST_PATH = True # Set False if getting frequent OOM on large tasks
85
+
86
+ # --- TASK ROUTING ---
87
+ # Tasks that should always run on CPU (typically I/O bound)
88
+ CPU_ONLY_TASKS = [
89
+ "LoadTensor",
90
+ "GatherTensors",
91
+ "SaveTensor",
92
+ "TensorWriterTask",
93
+ "FinalizeModel",
94
+ "PermutedEmbeddings", # Gather operations don't benefit from GPU
95
+ ]
96
+
97
+ # ============================================================================
98
+ # END OF CONFIGURATION SECTION
99
+ # ============================================================================
100
+
101
+ if sys.platform == "win32":
102
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = f"max_split_size_mb:{CUDA_MAX_SPLIT_SIZE_MB}"
103
+
104
+ ValueT = TypeVar("ValueT")
105
+ LOG = logging.getLogger(__name__)
106
+
107
+
108
+ def _round_to_power_of_2(n: int, prefer_lower: bool = True) -> int:
109
+ """Round to nearest power of 2 for memory alignment."""
110
+ if n <= 0:
111
+ return 1
112
+ if n == 1:
113
+ return 1
114
+
115
+ # Find the two nearest powers of 2
116
+ power = n.bit_length() - 1
117
+ lower = 1 << power
118
+ upper = 1 << (power + 1)
119
+
120
+ if prefer_lower or (n - lower) < (upper - n):
121
+ return lower
122
+ return upper
123
+
124
+
125
+ class Task(ABC, BaseModel, Generic[ValueT], frozen=True):
126
+ @abstractmethod
127
+ def arguments(self) -> Dict[str, "Task"]:
128
+ ...
129
+
130
+ @abstractmethod
131
+ def execute(self, **kwargs) -> ValueT:
132
+ ...
133
+
134
+ def priority(self) -> int:
135
+ return 0
136
+
137
+ def group_label(self) -> Optional[str]:
138
+ return None
139
+
140
+ def uses_accelerator(self) -> bool:
141
+ return False
142
+
143
+ def main_thread_only(self) -> bool:
144
+ return False
145
+
146
+ def duplicate_per_gpu(self) -> bool:
147
+ return False
148
+
149
+
150
+ class TaskUniverse:
151
+ tasks: List[Task]
152
+ task_to_index: Dict[Task, int]
153
+ task_arguments: Dict[int, Dict[str, int]]
154
+ _type_id_to_index: Dict[Tuple[type, int], int]
155
+
156
+ def __init__(self, tasks: Optional[Iterable[Task]] = None):
157
+ self.tasks = []
158
+ self.task_to_index = {}
159
+ self.task_arguments = {}
160
+ self._type_id_to_index = {}
161
+ if tasks is not None:
162
+ for task in tasks:
163
+ self.add_task(task)
164
+
165
+ def add_task(self, task: Task, recursive: bool = True) -> "TaskHandle":
166
+ _ti_key = (type(task), id(task))
167
+ if _ti_key in self._type_id_to_index:
168
+ index = self._type_id_to_index[_ti_key]
169
+ return TaskHandle(self, index)
170
+
171
+ index = self.task_to_index.setdefault(task, len(self.tasks))
172
+ if index < len(self.tasks):
173
+ return TaskHandle(self, index)
174
+ self.tasks.append(task)
175
+ self._type_id_to_index[_ti_key] = index
176
+
177
+ if recursive:
178
+ self.task_arguments[index] = {}
179
+ for k, v in task.arguments().items():
180
+ self.task_arguments[index][k] = self.add_task(v, recursive=True)._index
181
+ return TaskHandle(self, index)
182
+
183
+ def get_handle(self, task: Task) -> Optional["TaskHandle"]:
184
+ if task not in self.task_to_index:
185
+ return None
186
+ return TaskHandle(self, self.task_to_index[task])
187
+
188
+
189
+ class TaskHandle:
190
+ __slots__ = ["_universe", "_index"]
191
+ _universe: TaskUniverse
192
+ _index: int
193
+
194
+ def __init__(self, universe: TaskUniverse, index: int):
195
+ self._universe = universe
196
+ self._index = index
197
+
198
+ def task(self) -> Task:
199
+ return self._universe.tasks[self._index]
200
+
201
+ def arguments(self) -> Dict[str, "TaskHandle"]:
202
+ return {
203
+ k: TaskHandle(self._universe, v)
204
+ for k, v in self._universe.task_arguments[self._index].items()
205
+ }
206
+
207
+ def __eq__(self, other):
208
+ if not isinstance(other, TaskHandle):
209
+ return False
210
+ return self._index == other._index and self._universe is other._universe
211
+
212
+ def __hash__(self):
213
+ return self._index
214
+
215
+ def __str__(self):
216
+ return f"TaskHandle({type(self.task()).__name__}, {self._index})"
217
+
218
+ __repr__ = __str__
219
+
220
+
221
+ class ExecutionSchedule:
222
+ tasks: List[TaskHandle]
223
+ last_use_index: Dict[TaskHandle, int]
224
+
225
+ def __init__(self, tasks: List[TaskHandle], last_use_index: Dict[TaskHandle, int]):
226
+ self.tasks = tasks
227
+ self.last_use_index = last_use_index
228
+
229
+
230
+ def build_schedule(
231
+ targets: List[TaskHandle], cached_values: Dict[TaskHandle, Any]
232
+ ) -> ExecutionSchedule:
233
+ if not targets:
234
+ return ExecutionSchedule(tasks=[], last_use_index={})
235
+
236
+ universe = targets[0]._universe
237
+ dummy_handle = TaskHandle(universe, -1)
238
+ edge_tups: List[Tuple[TaskHandle, TaskHandle]] = []
239
+
240
+ explored = set()
241
+ to_explore = set(targets)
242
+ while to_explore:
243
+ task = to_explore.pop()
244
+ if task in explored:
245
+ continue
246
+ explored.add(task)
247
+ if task in (cached_values or {}):
248
+ continue
249
+ for dep in task.arguments().values():
250
+ to_explore.add(dep)
251
+ edge_tups.append((dep, task))
252
+
253
+ for target in targets:
254
+ edge_tups.append((dummy_handle, target))
255
+
256
+ def _compare_key(node: TaskHandle) -> Tuple[str, int]:
257
+ if node._index < 0:
258
+ return ("", 0)
259
+ task = node.task()
260
+ return (task.group_label() or "", -task.priority())
261
+
262
+ graph = networkx.DiGraph(edge_tups)
263
+ schedule: List[TaskHandle] = [
264
+ node
265
+ for node in networkx.lexicographical_topological_sort(graph, key=_compare_key)
266
+ if (node != dummy_handle) and node not in (cached_values or {})
267
+ ]
268
+
269
+ last_use_index = {}
270
+ for idx, task in reversed(list(enumerate(schedule))):
271
+ for dep in task.arguments().values():
272
+ if dep not in last_use_index:
273
+ last_use_index[dep] = idx
274
+ if task not in last_use_index:
275
+ last_use_index[task] = idx
276
+ for task in cached_values or {}:
277
+ if task not in last_use_index:
278
+ last_use_index[task] = len(schedule) + 1
279
+
280
+ return ExecutionSchedule(tasks=schedule, last_use_index=last_use_index)
281
+
282
+
283
+ class Executor:
284
+ math_device: torch.device
285
+ storage_device: torch.device
286
+ universe: TaskUniverse
287
+ targets: List[TaskHandle]
288
+ schedule: ExecutionSchedule
289
+ cached_values: Optional[Dict[TaskHandle, Any]]
290
+ _task_counter: int
291
+
292
+ def __init__(
293
+ self,
294
+ targets: Union[List[Task], List[TaskHandle]],
295
+ math_device: torch.device = torch.device("cpu"),
296
+ storage_device: torch.device = torch.device("cpu"),
297
+ cached_values: Optional[Dict[TaskHandle, Any]] = None,
298
+ ):
299
+ self.cached_values = cached_values
300
+ self._task_counter = 0
301
+
302
+ if isinstance(math_device, str):
303
+ math_device = torch.device(math_device)
304
+ if isinstance(storage_device, str):
305
+ storage_device = torch.device(storage_device)
306
+ self.math_device = math_device
307
+ self.storage_device = storage_device
308
+
309
+ if targets and isinstance(targets[0], Task):
310
+ universe = TaskUniverse(targets)
311
+ targets = [universe.add_task(t) for t in targets]
312
+ elif targets and isinstance(targets[0], TaskHandle):
313
+ universe = targets[0]._universe
314
+ elif not targets:
315
+ universe = TaskUniverse()
316
+ else:
317
+ raise ValueError("Targets must be a list of Task or TaskHandle instances")
318
+
319
+ self.universe = universe
320
+ self.targets = targets
321
+ self.schedule = build_schedule(targets, cached_values=cached_values)
322
+
323
+ def _slice_argument(self, arg: Any, start: int, end: int) -> Any:
324
+ """Recursively slice tensors within nested structures."""
325
+ if isinstance(arg, torch.Tensor):
326
+ if arg.shape[0] > 1:
327
+ return arg[start:end]
328
+ return arg
329
+ elif isinstance(arg, dict):
330
+ return {k: self._slice_argument(v, start, end) for k, v in arg.items()}
331
+ elif isinstance(arg, list):
332
+ return [self._slice_argument(v, start, end) for v in arg]
333
+ elif isinstance(arg, tuple):
334
+ return tuple(self._slice_argument(v, start, end) for v in arg)
335
+ return arg
336
+
337
+ def _get_memory_stats(self) -> Dict[str, float]:
338
+ """Get current VRAM statistics in GB."""
339
+ if self.math_device.type != "cuda":
340
+ return {}
341
+
342
+ allocated = torch.cuda.memory_allocated(self.math_device) / (1024**3)
343
+ reserved = torch.cuda.memory_reserved(self.math_device) / (1024**3)
344
+ total = torch.cuda.get_device_properties(self.math_device).total_memory / (1024**3)
345
+
346
+ return {
347
+ "allocated_gb": allocated,
348
+ "reserved_gb": reserved,
349
+ "total_gb": total,
350
+ "free_gb": total - allocated,
351
+ }
352
+
353
+ def _get_adaptive_chunk_size(self, task: Task, arguments: Dict[str, Any]) -> int:
354
+ """
355
+ Calculate optimal chunk size based on available VRAM and task requirements.
356
+
357
+ This implements the "measure.py strategy" of targeting a specific VRAM fill level
358
+ rather than using currently available memory, which prevents oscillation.
359
+ """
360
+ if self.math_device.type == "cpu":
361
+ return 1024 # Large default for CPU
362
+
363
+ # Get hardware capacity
364
+ total_vram = torch.cuda.get_device_properties(self.math_device).total_memory
365
+ target_bytes = TARGET_VRAM_GB * (1024**3)
366
+
367
+ # Analyze tensor dimensions and count
368
+ num_tensors = 0
369
+ width = 0
370
+ bytes_per_element = 4 # Default float32
371
+
372
+ for arg in arguments.values():
373
+ if isinstance(arg, torch.Tensor):
374
+ num_tensors += 1
375
+ width = max(width, arg.shape[-1] if len(arg.shape) > 1 else arg.shape[0])
376
+ bytes_per_element = arg.element_size()
377
+ elif isinstance(arg, dict):
378
+ for v in arg.values():
379
+ if isinstance(v, torch.Tensor):
380
+ num_tensors += 1
381
+ width = max(width, v.shape[-1] if len(v.shape) > 1 else v.shape[0])
382
+ bytes_per_element = v.element_size()
383
+
384
+ if num_tensors == 0 or width == 0:
385
+ return 512 # Safe default
386
+
387
+ # Get task-specific multiplier
388
+ task_name = type(task).__name__
389
+ multiplier = TASK_MULTIPLIERS.get("default", 1.2)
390
+
391
+ for key, mult in TASK_MULTIPLIERS.items():
392
+ if key in task_name:
393
+ multiplier = mult
394
+ break
395
+
396
+ # Calculate bytes per row with multiplier for working memory
397
+ bytes_per_row = num_tensors * width * bytes_per_element * multiplier
398
+
399
+ # Calculate usable VRAM (target minus current allocation and safety margin)
400
+ current_allocated = torch.cuda.memory_allocated(self.math_device)
401
+ safety_bytes = VRAM_SAFETY_MARGIN_GB * (1024**3)
402
+ usable_vram = max(target_bytes - current_allocated - safety_bytes, 1024 * (1024**2))
403
+
404
+ # Calculate chunk size
405
+ chunk_size = max(MIN_CHUNK_SIZE, int(usable_vram // bytes_per_row))
406
+
407
+ # Apply power-of-2 alignment if enabled (measure.py strategy)
408
+ if ENABLE_POWER_OF_2_ALIGNMENT and chunk_size > MIN_CHUNK_SIZE:
409
+ chunk_size = _round_to_power_of_2(chunk_size, prefer_lower=True)
410
+
411
+ LOG.debug(f"Calculated chunk size: {chunk_size} (tensors={num_tensors}, width={width}, mult={multiplier:.2f})")
412
+ return chunk_size
413
+
414
+ def _execute_chunked(self, task: Task, arguments: Dict[str, Any]) -> Any:
415
+ """
416
+ Execute task in chunks with progressive fallback strategy.
417
+
418
+ Strategy:
419
+ 1. Try adaptive chunk size
420
+ 2. On OOM, reduce by CHUNK_REDUCTION_FACTOR
421
+ 3. Continue until success or MIN_CHUNK_SIZE reached
422
+ """
423
+ # Find total rows to process
424
+ total_rows = 0
425
+ for arg in arguments.values():
426
+ if isinstance(arg, torch.Tensor):
427
+ total_rows = arg.shape[0]
428
+ break
429
+ elif isinstance(arg, dict):
430
+ for v in arg.values():
431
+ if isinstance(v, torch.Tensor):
432
+ total_rows = v.shape[0]
433
+ break
434
+ if total_rows > 0:
435
+ break
436
+
437
+ if total_rows == 0:
438
+ return task.execute(**arguments)
439
+
440
+ # Calculate initial chunk size
441
+ chunk_size = self._get_adaptive_chunk_size(task, arguments)
442
+
443
+ # FAST PATH: Try to execute all at once if chunk size >= total rows
444
+ if ENABLE_FAST_PATH and chunk_size >= total_rows:
445
+ try:
446
+ gpu_args = {
447
+ k: self._move_tensors(v, self.math_device)
448
+ for k, v in arguments.items()
449
+ }
450
+ res = task.execute(**gpu_args)
451
+ result = self._move_tensors(res, self.storage_device)
452
+ del gpu_args, res
453
+ if ENABLE_AGGRESSIVE_CLEANUP:
454
+ torch.cuda.empty_cache()
455
+ return result
456
+ except torch.OutOfMemoryError:
457
+ LOG.warning(f"Fast path OOM, falling back to chunking")
458
+ torch.cuda.empty_cache()
459
+ gc.collect()
460
+ chunk_size = max(MIN_CHUNK_SIZE, total_rows // 2)
461
+
462
+ # Chunked execution with progressive reduction
463
+ results = []
464
+ i = 0
465
+ oom_count = 0
466
+
467
+ while i < total_rows:
468
+ end = min(i + chunk_size, total_rows)
469
+
470
+ try:
471
+ chunk_args_gpu = {
472
+ k: self._move_tensors(self._slice_argument(v, i, end), self.math_device)
473
+ for k, v in arguments.items()
474
+ }
475
+ chunk_res = task.execute(**chunk_args_gpu)
476
+ results.append(self._move_tensors(chunk_res, self.storage_device))
477
+
478
+ del chunk_args_gpu, chunk_res
479
+
480
+ # Aggressive cleanup per measure.py strategy
481
+ if ENABLE_AGGRESSIVE_CLEANUP:
482
+ torch.cuda.empty_cache()
483
+
484
+ i = end # Move to next chunk
485
+ oom_count = 0 # Reset OOM counter on success
486
+
487
+ except torch.OutOfMemoryError:
488
+ oom_count += 1
489
+ torch.cuda.empty_cache()
490
+ gc.collect()
491
+
492
+ # Progressive reduction
493
+ old_chunk = chunk_size
494
+ chunk_size = max(MIN_CHUNK_SIZE, int(chunk_size * CHUNK_REDUCTION_FACTOR))
495
+
496
+ # Apply power-of-2 alignment
497
+ if ENABLE_POWER_OF_2_ALIGNMENT:
498
+ chunk_size = _round_to_power_of_2(chunk_size, prefer_lower=True)
499
+
500
+ if chunk_size < MIN_CHUNK_SIZE:
501
+ LOG.error(f"Chunk size below minimum ({MIN_CHUNK_SIZE}), cannot continue")
502
+ raise
503
+
504
+ LOG.warning(
505
+ f"OOM at chunk {old_chunk}, reducing to {chunk_size} "
506
+ f"(attempt {oom_count}, progress: {i}/{total_rows})"
507
+ )
508
+
509
+ # Safety: if we OOM too many times, something is wrong
510
+ if oom_count > 10:
511
+ LOG.error("Too many OOM errors, giving up")
512
+ raise
513
+
514
+ # Concatenate results
515
+ if not results:
516
+ return None
517
+
518
+ if isinstance(results[0], torch.Tensor):
519
+ return torch.cat(results, dim=0)
520
+ elif isinstance(results[0], dict):
521
+ out = {}
522
+ for k in results[0].keys():
523
+ out[k] = torch.cat([r[k] for r in results], dim=0)
524
+ return out
525
+
526
+ return results
527
+
528
+ def _execute_with_fallback(self, task: Task, arguments: Dict[str, Any], accelerator) -> Any:
529
+ """
530
+ Execute task with comprehensive fallback strategy.
531
+
532
+ Strategy:
533
+ 1. Try full GPU execution
534
+ 2. Try adaptive chunking
535
+ 3. Try fixed chunk sizes
536
+ 4. Fall back to CPU
537
+ """
538
+ task_name = type(task).__name__
539
+
540
+ # Strategy 1: Try full GPU execution for light tasks
541
+ try:
542
+ gpu_args = {
543
+ k: self._move_tensors(v, self.math_device)
544
+ for k, v in arguments.items()
545
+ }
546
+ res = task.execute(**gpu_args)
547
+ result = self._move_tensors(res, self.storage_device)
548
+ del gpu_args, res
549
+ return result
550
+ except torch.OutOfMemoryError:
551
+ LOG.debug(f"Full GPU execution failed for {task_name}, trying chunked")
552
+ torch.cuda.empty_cache()
553
+ gc.collect()
554
+ except Exception as e:
555
+ LOG.warning(f"GPU execution error for {task_name}: {e}")
556
+ torch.cuda.empty_cache()
557
+ raise
558
+
559
+ # Strategy 2: Try adaptive chunking
560
+ try:
561
+ result = self._execute_chunked(task, arguments)
562
+ return result
563
+ except torch.OutOfMemoryError:
564
+ LOG.warning(f"Adaptive chunking failed for {task_name}, trying fixed sizes")
565
+ torch.cuda.empty_cache()
566
+ gc.collect()
567
+ except Exception as e:
568
+ LOG.warning(f"Chunking error for {task_name}: {e}")
569
+ raise
570
+
571
+ # Strategy 3: Try fixed chunk sizes
572
+ for chunk_size in FALLBACK_CHUNK_SIZES:
573
+ if chunk_size < MIN_CHUNK_SIZE:
574
+ continue
575
+
576
+ try:
577
+ LOG.info(f"Trying fixed chunk size {chunk_size} for {task_name}")
578
+
579
+ # Get total rows
580
+ total_rows = 0
581
+ for arg in arguments.values():
582
+ if isinstance(arg, torch.Tensor):
583
+ total_rows = arg.shape[0]
584
+ break
585
+ elif isinstance(arg, dict):
586
+ for v in arg.values():
587
+ if isinstance(v, torch.Tensor):
588
+ total_rows = v.shape[0]
589
+ break
590
+ if total_rows > 0:
591
+ break
592
+
593
+ if total_rows == 0:
594
+ break
595
+
596
+ results = []
597
+ for i in range(0, total_rows, chunk_size):
598
+ end = min(i + chunk_size, total_rows)
599
+ chunk_args = {
600
+ k: self._slice_argument(v, i, end)
601
+ for k, v in arguments.items()
602
+ }
603
+ chunk_args_gpu = {
604
+ k: self._move_tensors(v, self.math_device)
605
+ for k, v in chunk_args.items()
606
+ }
607
+ chunk_res = task.execute(**chunk_args_gpu)
608
+ results.append(self._move_tensors(chunk_res, self.storage_device))
609
+ del chunk_args, chunk_args_gpu, chunk_res
610
+
611
+ if ENABLE_AGGRESSIVE_CLEANUP:
612
+ torch.cuda.empty_cache()
613
+
614
+ if isinstance(results[0], torch.Tensor):
615
+ return torch.cat(results, dim=0)
616
+ elif isinstance(results[0], dict):
617
+ out = {}
618
+ for k in results[0].keys():
619
+ out[k] = torch.cat([r[k] for r in results], dim=0)
620
+ return out
621
+ return results
622
+
623
+ except torch.OutOfMemoryError:
624
+ torch.cuda.empty_cache()
625
+ gc.collect()
626
+ continue
627
+ except Exception as e:
628
+ LOG.warning(f"Fixed chunk {chunk_size} failed: {e}")
629
+ break
630
+
631
+ # Strategy 4: CPU fallback
632
+ LOG.warning(f"All GPU strategies failed for {task_name}, using CPU")
633
+ raise torch.OutOfMemoryError("Forcing CPU fallback")
634
+
635
+ def _run(
636
+ self,
637
+ quiet: bool = False,
638
+ desc: Optional[str] = None,
639
+ ) -> Iterator[Tuple[TaskHandle, Any]]:
640
+ last_use_index = self.schedule.last_use_index
641
+
642
+ values: Dict[TaskHandle, Any] = {}
643
+ if self.cached_values:
644
+ for task, value in self.cached_values.items():
645
+ values[task] = value
646
+
647
+ is_gpu_execution = self.math_device.type != "cpu"
648
+ accelerator = get_torch_accelerator_module(self.math_device.type) if is_gpu_execution else None
649
+
650
+ for idx, task_handle in (
651
+ pbar := tqdm.tqdm(
652
+ list(enumerate(self.schedule.tasks)),
653
+ disable=quiet,
654
+ desc=desc or "Executing graph",
655
+ )
656
+ ):
657
+ task = task_handle.task()
658
+ task_type = type(task).__name__
659
+
660
+ # Log memory stats periodically
661
+ if is_gpu_execution and idx % 10 == 0:
662
+ stats = self._get_memory_stats()
663
+ LOG.debug(
664
+ f"Memory: {stats.get('allocated_gb', 0):.2f}GB allocated, "
665
+ f"{stats.get('free_gb', 0):.2f}GB free of {stats.get('total_gb', 0):.2f}GB"
666
+ )
667
+
668
+ # Determine execution strategy
669
+ is_cpu_only_task = task_type in CPU_ONLY_TASKS
670
+ want_gpu = is_gpu_execution and task.uses_accelerator() and not is_cpu_only_task
671
+
672
+ # Collect arguments
673
+ arguments = {k: values[h] for k, h in task_handle.arguments().items()}
674
+
675
+ success = False
676
+
677
+ # Try GPU execution
678
+ if want_gpu:
679
+ try:
680
+ res = self._execute_with_fallback(task, arguments, accelerator)
681
+ values[task_handle] = res
682
+ success = True
683
+ except torch.OutOfMemoryError:
684
+ LOG.warning(f"All GPU strategies exhausted for {task_type}, falling back to CPU")
685
+ success = False
686
+ except Exception as e:
687
+ LOG.error(f"GPU execution failed for {task_type}: {e}")
688
+ success = False
689
+
690
+ # Cleanup after GPU attempt
691
+ if is_gpu_execution and ENABLE_AGGRESSIVE_CLEANUP:
692
+ gc.collect()
693
+ if accelerator:
694
+ accelerator.empty_cache()
695
+
696
+ # CPU fallback
697
+ if not success:
698
+ if want_gpu:
699
+ LOG.info(f"Executing {task_type} on CPU")
700
+
701
+ # Ensure cleanup before CPU execution
702
+ if is_gpu_execution:
703
+ gc.collect()
704
+ if accelerator:
705
+ accelerator.empty_cache()
706
+
707
+ # Move arguments to CPU
708
+ cpu_arguments = {
709
+ k: self._move_tensors(v, torch.device("cpu"))
710
+ for k, v in arguments.items()
711
+ }
712
+
713
+ res = task.execute(**cpu_arguments)
714
+ del cpu_arguments
715
+ res = self._move_tensors(res, self.storage_device)
716
+ values[task_handle] = res
717
+
718
+ del res
719
+ del arguments
720
+
721
+ if task_handle in self.targets:
722
+ yield (task_handle, values[task_handle])
723
+
724
+ # Evict unreferenced values
725
+ expired = []
726
+ for key in values:
727
+ if idx >= last_use_index[key]:
728
+ expired.append(key)
729
+ for key in expired:
730
+ del values[key]
731
+
732
+ # Periodic cleanup (measure.py strategy)
733
+ self._task_counter += 1
734
+ if is_gpu_execution and ENABLE_AGGRESSIVE_CLEANUP:
735
+ if CLEANUP_FREQUENCY == 0 or self._task_counter % max(1, CLEANUP_FREQUENCY) == 0:
736
+ gc.collect()
737
+ if accelerator:
738
+ accelerator.empty_cache()
739
+
740
+ del values
741
+ del pbar
742
+
743
+ def run(
744
+ self,
745
+ quiet: bool = False,
746
+ desc: Optional[str] = None,
747
+ ) -> Iterator[Tuple[Task, Any]]:
748
+ for handle, value in self._run(quiet=quiet, desc=desc):
749
+ yield (handle.task(), value)
750
+
751
+ def execute(self, desc: Optional[str] = None) -> None:
752
+ for _ in self.run(desc=desc):
753
+ pass
754
+
755
+ def _move_tensors(
756
+ self, value: Any, device: torch.device, non_blocking: Optional[bool] = None
757
+ ) -> Any:
758
+ """Move tensors to specified device, handling nested structures."""
759
+ if non_blocking is None:
760
+ non_blocking = device.type in ["cuda", "xpu"]
761
+
762
+ if isinstance(value, torch.Tensor):
763
+ if value.device == device:
764
+ return value
765
+ return value.to(device=device, non_blocking=non_blocking)
766
+ elif isinstance(value, dict):
767
+ return {
768
+ k: self._move_tensors(v, device, non_blocking)
769
+ for k, v in value.items()
770
+ }
771
+ elif isinstance(value, list):
772
+ return [self._move_tensors(v, device, non_blocking) for v in value]
773
+ elif isinstance(value, tuple):
774
+ return tuple(self._move_tensors(v, device, non_blocking) for v in value)
775
+
776
+ return value
graph_v18_runpod_A40.py ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # graph_v18.py - Optimized for A40 runpod
2
+ # Copyright (C) 2025 Arcee AI
3
+ # SPDX-License-Identifier: LGPL-3.0-only
4
+ """
5
+ Module for computational graph execution.
6
+
7
+ Classes:
8
+ Task: Abstract base class representing a computational task.
9
+ Executor: Class for scheduling and executing directed acyclic task graphs.
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ import gc
15
+ import logging
16
+ import networkx
17
+ import torch
18
+ import tqdm
19
+ from pydantic import BaseModel
20
+ from typing_extensions import Generic, TypeVar
21
+ from abc import ABC, abstractmethod
22
+ from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
23
+
24
+ from mergekit.common import get_torch_accelerator_module
25
+
26
+ # ============================================================================
27
+ # CONFIGURATION SECTION - TUNE THESE PARAMETERS FOR YOUR GPU
28
+ # ============================================================================
29
+
30
+ # --- PRIMARY VRAM TARGETS ---
31
+ # For 3060 TI (8GB): Start with 7.2-7.4GB. Increase if stable, decrease if OOM.
32
+ # For 3060 (12GB): Try 10.5-11.0GB
33
+ # For 4GB cards: Try 3.2-3.5GB
34
+ TARGET_VRAM_GB = 47 # Target VRAM usage in GB (TUNE THIS FIRST)
35
+
36
+ # Safety margin to account for PyTorch overhead and fragmentation
37
+ # Windows typically needs ~0.8GB, Linux ~0.5GB
38
+ VRAM_SAFETY_MARGIN_GB = 1.0 # Reduce to 0.5-0.6 on Linux, increase to 1.0 if unstable
39
+
40
+ # --- CUDA MEMORY ALLOCATOR CONFIGURATION ---
41
+ # Smaller values = less fragmentation but more overhead
42
+ # 24MB is optimal for 8GB cards, 32MB for 12GB+ cards
43
+ CUDA_MAX_SPLIT_SIZE_MB = 24 # Options: 16, 24, 32, 64
44
+
45
+ # --- CHUNK SIZE BEHAVIOR ---
46
+ # How aggressively to reduce chunk size on OOM (0.5-0.9 range)
47
+ # Lower = more conservative (slower but safer), Higher = more aggressive
48
+ CHUNK_REDUCTION_FACTOR = 0.75 # Options: 0.5 (safe), 0.7 (balanced), 0.85 (aggressive)
49
+
50
+ # Minimum chunk size before giving up and falling back to CPU
51
+ MIN_CHUNK_SIZE = 1 # Usually keep at 1, increase to 4-8 if seeing micro-chunk overhead
52
+
53
+ # Enable power-of-2 alignment for chunk sizes (following measure.py strategy)
54
+ # This improves memory allocation efficiency
55
+ ENABLE_POWER_OF_2_ALIGNMENT = True # Set False if causing issues
56
+
57
+ # --- TASK-SPECIFIC MEMORY MULTIPLIERS ---
58
+ # These control how much extra VRAM to reserve for specific task types
59
+ # Increase if task OOMs, decrease if underutilizing VRAM
60
+ TASK_MULTIPLIERS = {
61
+ "ModelStock": 2.0,
62
+ "Karcher": 3.0,
63
+ "Consensus": 3.0,
64
+ "Prometheus": 6.0, # Forces the 3090 to start with ~8k chunks instead of 65k.
65
+ "default": 1.2,
66
+ }
67
+
68
+ # --- MEMORY CLEANUP BEHAVIOR ---
69
+ # Enable aggressive garbage collection and cache clearing
70
+ # True = slower but more stable, False = faster but may fragment memory
71
+ ENABLE_AGGRESSIVE_CLEANUP = True # Set False if merges are very stable
72
+
73
+ # How often to force cleanup (every N tasks). 0 = after every task
74
+ CLEANUP_FREQUENCY = 2 # Options: 0 (always), 1, 2, 5, 10
75
+
76
+ # --- FALLBACK STRATEGY ---
77
+ # Fixed chunk sizes to try if adaptive chunking fails
78
+ # Powers of 2 work best for GPU memory alignment
79
+ FALLBACK_CHUNK_SIZES = [32768, 16384, 8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2]
80
+
81
+ # --- FAST PATH OPTIMIZATION ---
82
+ # Try to execute entire task at once before chunking
83
+ # True = faster when it works, False = always chunk (more conservative)
84
+ ENABLE_FAST_PATH = True # Set False if getting frequent OOM on large tasks
85
+
86
+ # --- TASK ROUTING ---
87
+ # Tasks that should always run on CPU (typically I/O bound)
88
+ CPU_ONLY_TASKS = [
89
+ "LoadTensor",
90
+ "GatherTensors",
91
+ "SaveTensor",
92
+ "TensorWriterTask",
93
+ "FinalizeModel",
94
+ "PermutedEmbeddings", # Gather operations don't benefit from GPU
95
+ ]
96
+
97
+ # ============================================================================
98
+ # END OF CONFIGURATION SECTION
99
+ # ============================================================================
100
+
101
+ if sys.platform == "win32":
102
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = f"max_split_size_mb:{CUDA_MAX_SPLIT_SIZE_MB}"
103
+
104
+ ValueT = TypeVar("ValueT")
105
+ LOG = logging.getLogger(__name__)
106
+
107
+
108
+ def _round_to_power_of_2(n: int, prefer_lower: bool = True) -> int:
109
+ """Round to nearest power of 2 for memory alignment."""
110
+ if n <= 0:
111
+ return 1
112
+ if n == 1:
113
+ return 1
114
+
115
+ # Find the two nearest powers of 2
116
+ power = n.bit_length() - 1
117
+ lower = 1 << power
118
+ upper = 1 << (power + 1)
119
+
120
+ if prefer_lower or (n - lower) < (upper - n):
121
+ return lower
122
+ return upper
123
+
124
+
125
+ class Task(ABC, BaseModel, Generic[ValueT], frozen=True):
126
+ @abstractmethod
127
+ def arguments(self) -> Dict[str, "Task"]:
128
+ ...
129
+
130
+ @abstractmethod
131
+ def execute(self, **kwargs) -> ValueT:
132
+ ...
133
+
134
+ def priority(self) -> int:
135
+ return 0
136
+
137
+ def group_label(self) -> Optional[str]:
138
+ return None
139
+
140
+ def uses_accelerator(self) -> bool:
141
+ return False
142
+
143
+ def main_thread_only(self) -> bool:
144
+ return False
145
+
146
+ def duplicate_per_gpu(self) -> bool:
147
+ return False
148
+
149
+
150
+ class TaskUniverse:
151
+ tasks: List[Task]
152
+ task_to_index: Dict[Task, int]
153
+ task_arguments: Dict[int, Dict[str, int]]
154
+ _type_id_to_index: Dict[Tuple[type, int], int]
155
+
156
+ def __init__(self, tasks: Optional[Iterable[Task]] = None):
157
+ self.tasks = []
158
+ self.task_to_index = {}
159
+ self.task_arguments = {}
160
+ self._type_id_to_index = {}
161
+ if tasks is not None:
162
+ for task in tasks:
163
+ self.add_task(task)
164
+
165
+ def add_task(self, task: Task, recursive: bool = True) -> "TaskHandle":
166
+ _ti_key = (type(task), id(task))
167
+ if _ti_key in self._type_id_to_index:
168
+ index = self._type_id_to_index[_ti_key]
169
+ return TaskHandle(self, index)
170
+
171
+ index = self.task_to_index.setdefault(task, len(self.tasks))
172
+ if index < len(self.tasks):
173
+ return TaskHandle(self, index)
174
+ self.tasks.append(task)
175
+ self._type_id_to_index[_ti_key] = index
176
+
177
+ if recursive:
178
+ self.task_arguments[index] = {}
179
+ for k, v in task.arguments().items():
180
+ self.task_arguments[index][k] = self.add_task(v, recursive=True)._index
181
+ return TaskHandle(self, index)
182
+
183
+ def get_handle(self, task: Task) -> Optional["TaskHandle"]:
184
+ if task not in self.task_to_index:
185
+ return None
186
+ return TaskHandle(self, self.task_to_index[task])
187
+
188
+
189
+ class TaskHandle:
190
+ __slots__ = ["_universe", "_index"]
191
+ _universe: TaskUniverse
192
+ _index: int
193
+
194
+ def __init__(self, universe: TaskUniverse, index: int):
195
+ self._universe = universe
196
+ self._index = index
197
+
198
+ def task(self) -> Task:
199
+ return self._universe.tasks[self._index]
200
+
201
+ def arguments(self) -> Dict[str, "TaskHandle"]:
202
+ return {
203
+ k: TaskHandle(self._universe, v)
204
+ for k, v in self._universe.task_arguments[self._index].items()
205
+ }
206
+
207
+ def __eq__(self, other):
208
+ if not isinstance(other, TaskHandle):
209
+ return False
210
+ return self._index == other._index and self._universe is other._universe
211
+
212
+ def __hash__(self):
213
+ return self._index
214
+
215
+ def __str__(self):
216
+ return f"TaskHandle({type(self.task()).__name__}, {self._index})"
217
+
218
+ __repr__ = __str__
219
+
220
+
221
+ class ExecutionSchedule:
222
+ tasks: List[TaskHandle]
223
+ last_use_index: Dict[TaskHandle, int]
224
+
225
+ def __init__(self, tasks: List[TaskHandle], last_use_index: Dict[TaskHandle, int]):
226
+ self.tasks = tasks
227
+ self.last_use_index = last_use_index
228
+
229
+
230
+ def build_schedule(
231
+ targets: List[TaskHandle], cached_values: Dict[TaskHandle, Any]
232
+ ) -> ExecutionSchedule:
233
+ if not targets:
234
+ return ExecutionSchedule(tasks=[], last_use_index={})
235
+
236
+ universe = targets[0]._universe
237
+ dummy_handle = TaskHandle(universe, -1)
238
+ edge_tups: List[Tuple[TaskHandle, TaskHandle]] = []
239
+
240
+ explored = set()
241
+ to_explore = set(targets)
242
+ while to_explore:
243
+ task = to_explore.pop()
244
+ if task in explored:
245
+ continue
246
+ explored.add(task)
247
+ if task in (cached_values or {}):
248
+ continue
249
+ for dep in task.arguments().values():
250
+ to_explore.add(dep)
251
+ edge_tups.append((dep, task))
252
+
253
+ for target in targets:
254
+ edge_tups.append((dummy_handle, target))
255
+
256
+ def _compare_key(node: TaskHandle) -> Tuple[str, int]:
257
+ if node._index < 0:
258
+ return ("", 0)
259
+ task = node.task()
260
+ return (task.group_label() or "", -task.priority())
261
+
262
+ graph = networkx.DiGraph(edge_tups)
263
+ schedule: List[TaskHandle] = [
264
+ node
265
+ for node in networkx.lexicographical_topological_sort(graph, key=_compare_key)
266
+ if (node != dummy_handle) and node not in (cached_values or {})
267
+ ]
268
+
269
+ last_use_index = {}
270
+ for idx, task in reversed(list(enumerate(schedule))):
271
+ for dep in task.arguments().values():
272
+ if dep not in last_use_index:
273
+ last_use_index[dep] = idx
274
+ if task not in last_use_index:
275
+ last_use_index[task] = idx
276
+ for task in cached_values or {}:
277
+ if task not in last_use_index:
278
+ last_use_index[task] = len(schedule) + 1
279
+
280
+ return ExecutionSchedule(tasks=schedule, last_use_index=last_use_index)
281
+
282
+
283
+ class Executor:
284
+ math_device: torch.device
285
+ storage_device: torch.device
286
+ universe: TaskUniverse
287
+ targets: List[TaskHandle]
288
+ schedule: ExecutionSchedule
289
+ cached_values: Optional[Dict[TaskHandle, Any]]
290
+ _task_counter: int
291
+
292
+ def __init__(
293
+ self,
294
+ targets: Union[List[Task], List[TaskHandle]],
295
+ math_device: torch.device = torch.device("cpu"),
296
+ storage_device: torch.device = torch.device("cpu"),
297
+ cached_values: Optional[Dict[TaskHandle, Any]] = None,
298
+ ):
299
+ self.cached_values = cached_values
300
+ self._task_counter = 0
301
+
302
+ if isinstance(math_device, str):
303
+ math_device = torch.device(math_device)
304
+ if isinstance(storage_device, str):
305
+ storage_device = torch.device(storage_device)
306
+ self.math_device = math_device
307
+ self.storage_device = storage_device
308
+
309
+ if targets and isinstance(targets[0], Task):
310
+ universe = TaskUniverse(targets)
311
+ targets = [universe.add_task(t) for t in targets]
312
+ elif targets and isinstance(targets[0], TaskHandle):
313
+ universe = targets[0]._universe
314
+ elif not targets:
315
+ universe = TaskUniverse()
316
+ else:
317
+ raise ValueError("Targets must be a list of Task or TaskHandle instances")
318
+
319
+ self.universe = universe
320
+ self.targets = targets
321
+ self.schedule = build_schedule(targets, cached_values=cached_values)
322
+
323
+ def _slice_argument(self, arg: Any, start: int, end: int) -> Any:
324
+ """Recursively slice tensors within nested structures."""
325
+ if isinstance(arg, torch.Tensor):
326
+ if arg.shape[0] > 1:
327
+ return arg[start:end]
328
+ return arg
329
+ elif isinstance(arg, dict):
330
+ return {k: self._slice_argument(v, start, end) for k, v in arg.items()}
331
+ elif isinstance(arg, list):
332
+ return [self._slice_argument(v, start, end) for v in arg]
333
+ elif isinstance(arg, tuple):
334
+ return tuple(self._slice_argument(v, start, end) for v in arg)
335
+ return arg
336
+
337
+ def _get_memory_stats(self) -> Dict[str, float]:
338
+ """Get current VRAM statistics in GB."""
339
+ if self.math_device.type != "cuda":
340
+ return {}
341
+
342
+ allocated = torch.cuda.memory_allocated(self.math_device) / (1024**3)
343
+ reserved = torch.cuda.memory_reserved(self.math_device) / (1024**3)
344
+ total = torch.cuda.get_device_properties(self.math_device).total_memory / (1024**3)
345
+
346
+ return {
347
+ "allocated_gb": allocated,
348
+ "reserved_gb": reserved,
349
+ "total_gb": total,
350
+ "free_gb": total - allocated,
351
+ }
352
+
353
+ def _get_adaptive_chunk_size(self, task: Task, arguments: Dict[str, Any]) -> int:
354
+ """
355
+ Calculate optimal chunk size based on available VRAM and task requirements.
356
+
357
+ This implements the "measure.py strategy" of targeting a specific VRAM fill level
358
+ rather than using currently available memory, which prevents oscillation.
359
+ """
360
+ if self.math_device.type == "cpu":
361
+ return 1024 # Large default for CPU
362
+
363
+ # Get hardware capacity
364
+ total_vram = torch.cuda.get_device_properties(self.math_device).total_memory
365
+ target_bytes = TARGET_VRAM_GB * (1024**3)
366
+
367
+ # Analyze tensor dimensions and count
368
+ num_tensors = 0
369
+ width = 0
370
+ bytes_per_element = 4 # Default float32
371
+
372
+ for arg in arguments.values():
373
+ if isinstance(arg, torch.Tensor):
374
+ num_tensors += 1
375
+ width = max(width, arg.shape[-1] if len(arg.shape) > 1 else arg.shape[0])
376
+ bytes_per_element = arg.element_size()
377
+ elif isinstance(arg, dict):
378
+ for v in arg.values():
379
+ if isinstance(v, torch.Tensor):
380
+ num_tensors += 1
381
+ width = max(width, v.shape[-1] if len(v.shape) > 1 else v.shape[0])
382
+ bytes_per_element = v.element_size()
383
+
384
+ if num_tensors == 0 or width == 0:
385
+ return 512 # Safe default
386
+
387
+ # Get task-specific multiplier
388
+ task_name = type(task).__name__
389
+ multiplier = TASK_MULTIPLIERS.get("default", 1.2)
390
+
391
+ for key, mult in TASK_MULTIPLIERS.items():
392
+ if key in task_name:
393
+ multiplier = mult
394
+ break
395
+
396
+ # Calculate bytes per row with multiplier for working memory
397
+ bytes_per_row = num_tensors * width * bytes_per_element * multiplier
398
+
399
+ # Calculate usable VRAM (target minus current allocation and safety margin)
400
+ current_allocated = torch.cuda.memory_allocated(self.math_device)
401
+ safety_bytes = VRAM_SAFETY_MARGIN_GB * (1024**3)
402
+ usable_vram = max(target_bytes - current_allocated - safety_bytes, 1024 * (1024**2))
403
+
404
+ # Calculate chunk size
405
+ chunk_size = max(MIN_CHUNK_SIZE, int(usable_vram // bytes_per_row))
406
+
407
+ # Apply power-of-2 alignment if enabled (measure.py strategy)
408
+ if ENABLE_POWER_OF_2_ALIGNMENT and chunk_size > MIN_CHUNK_SIZE:
409
+ chunk_size = _round_to_power_of_2(chunk_size, prefer_lower=True)
410
+
411
+ LOG.debug(f"Calculated chunk size: {chunk_size} (tensors={num_tensors}, width={width}, mult={multiplier:.2f})")
412
+ return chunk_size
413
+
414
+ def _execute_chunked(self, task: Task, arguments: Dict[str, Any]) -> Any:
415
+ """
416
+ Execute task in chunks with progressive fallback strategy.
417
+
418
+ Strategy:
419
+ 1. Try adaptive chunk size
420
+ 2. On OOM, reduce by CHUNK_REDUCTION_FACTOR
421
+ 3. Continue until success or MIN_CHUNK_SIZE reached
422
+ """
423
+ # Find total rows to process
424
+ total_rows = 0
425
+ for arg in arguments.values():
426
+ if isinstance(arg, torch.Tensor):
427
+ total_rows = arg.shape[0]
428
+ break
429
+ elif isinstance(arg, dict):
430
+ for v in arg.values():
431
+ if isinstance(v, torch.Tensor):
432
+ total_rows = v.shape[0]
433
+ break
434
+ if total_rows > 0:
435
+ break
436
+
437
+ if total_rows == 0:
438
+ return task.execute(**arguments)
439
+
440
+ # Calculate initial chunk size
441
+ chunk_size = self._get_adaptive_chunk_size(task, arguments)
442
+
443
+ # FAST PATH: Try to execute all at once if chunk size >= total rows
444
+ if ENABLE_FAST_PATH and chunk_size >= total_rows:
445
+ try:
446
+ gpu_args = {
447
+ k: self._move_tensors(v, self.math_device)
448
+ for k, v in arguments.items()
449
+ }
450
+ res = task.execute(**gpu_args)
451
+ result = self._move_tensors(res, self.storage_device)
452
+ del gpu_args, res
453
+ if ENABLE_AGGRESSIVE_CLEANUP:
454
+ torch.cuda.empty_cache()
455
+ return result
456
+ except torch.OutOfMemoryError:
457
+ LOG.warning(f"Fast path OOM, falling back to chunking")
458
+ torch.cuda.empty_cache()
459
+ gc.collect()
460
+ chunk_size = max(MIN_CHUNK_SIZE, total_rows // 2)
461
+
462
+ # Chunked execution with progressive reduction
463
+ results = []
464
+ i = 0
465
+ oom_count = 0
466
+
467
+ while i < total_rows:
468
+ end = min(i + chunk_size, total_rows)
469
+
470
+ try:
471
+ chunk_args_gpu = {
472
+ k: self._move_tensors(self._slice_argument(v, i, end), self.math_device)
473
+ for k, v in arguments.items()
474
+ }
475
+ chunk_res = task.execute(**chunk_args_gpu)
476
+ results.append(self._move_tensors(chunk_res, self.storage_device))
477
+
478
+ del chunk_args_gpu, chunk_res
479
+
480
+ # Aggressive cleanup per measure.py strategy
481
+ if ENABLE_AGGRESSIVE_CLEANUP:
482
+ torch.cuda.empty_cache()
483
+
484
+ i = end # Move to next chunk
485
+ oom_count = 0 # Reset OOM counter on success
486
+
487
+ except torch.OutOfMemoryError:
488
+ oom_count += 1
489
+ torch.cuda.empty_cache()
490
+ gc.collect()
491
+
492
+ # Progressive reduction
493
+ old_chunk = chunk_size
494
+ chunk_size = max(MIN_CHUNK_SIZE, int(chunk_size * CHUNK_REDUCTION_FACTOR))
495
+
496
+ # Apply power-of-2 alignment
497
+ if ENABLE_POWER_OF_2_ALIGNMENT:
498
+ chunk_size = _round_to_power_of_2(chunk_size, prefer_lower=True)
499
+
500
+ if chunk_size < MIN_CHUNK_SIZE:
501
+ LOG.error(f"Chunk size below minimum ({MIN_CHUNK_SIZE}), cannot continue")
502
+ raise
503
+
504
+ LOG.warning(
505
+ f"OOM at chunk {old_chunk}, reducing to {chunk_size} "
506
+ f"(attempt {oom_count}, progress: {i}/{total_rows})"
507
+ )
508
+
509
+ # Safety: if we OOM too many times, something is wrong
510
+ if oom_count > 10:
511
+ LOG.error("Too many OOM errors, giving up")
512
+ raise
513
+
514
+ # Concatenate results
515
+ if not results:
516
+ return None
517
+
518
+ if isinstance(results[0], torch.Tensor):
519
+ return torch.cat(results, dim=0)
520
+ elif isinstance(results[0], dict):
521
+ out = {}
522
+ for k in results[0].keys():
523
+ out[k] = torch.cat([r[k] for r in results], dim=0)
524
+ return out
525
+
526
+ return results
527
+
528
+ def _execute_with_fallback(self, task: Task, arguments: Dict[str, Any], accelerator) -> Any:
529
+ """
530
+ Execute task with comprehensive fallback strategy.
531
+
532
+ Strategy:
533
+ 1. Try full GPU execution
534
+ 2. Try adaptive chunking
535
+ 3. Try fixed chunk sizes
536
+ 4. Fall back to CPU
537
+ """
538
+ task_name = type(task).__name__
539
+
540
+ # Strategy 1: Try full GPU execution for light tasks
541
+ try:
542
+ gpu_args = {
543
+ k: self._move_tensors(v, self.math_device)
544
+ for k, v in arguments.items()
545
+ }
546
+ res = task.execute(**gpu_args)
547
+ result = self._move_tensors(res, self.storage_device)
548
+ del gpu_args, res
549
+ return result
550
+ except torch.OutOfMemoryError:
551
+ LOG.debug(f"Full GPU execution failed for {task_name}, trying chunked")
552
+ torch.cuda.empty_cache()
553
+ gc.collect()
554
+ except Exception as e:
555
+ LOG.warning(f"GPU execution error for {task_name}: {e}")
556
+ torch.cuda.empty_cache()
557
+ raise
558
+
559
+ # Strategy 2: Try adaptive chunking
560
+ try:
561
+ result = self._execute_chunked(task, arguments)
562
+ return result
563
+ except torch.OutOfMemoryError:
564
+ LOG.warning(f"Adaptive chunking failed for {task_name}, trying fixed sizes")
565
+ torch.cuda.empty_cache()
566
+ gc.collect()
567
+ except Exception as e:
568
+ LOG.warning(f"Chunking error for {task_name}: {e}")
569
+ raise
570
+
571
+ # Strategy 3: Try fixed chunk sizes
572
+ for chunk_size in FALLBACK_CHUNK_SIZES:
573
+ if chunk_size < MIN_CHUNK_SIZE:
574
+ continue
575
+
576
+ try:
577
+ LOG.info(f"Trying fixed chunk size {chunk_size} for {task_name}")
578
+
579
+ # Get total rows
580
+ total_rows = 0
581
+ for arg in arguments.values():
582
+ if isinstance(arg, torch.Tensor):
583
+ total_rows = arg.shape[0]
584
+ break
585
+ elif isinstance(arg, dict):
586
+ for v in arg.values():
587
+ if isinstance(v, torch.Tensor):
588
+ total_rows = v.shape[0]
589
+ break
590
+ if total_rows > 0:
591
+ break
592
+
593
+ if total_rows == 0:
594
+ break
595
+
596
+ results = []
597
+ for i in range(0, total_rows, chunk_size):
598
+ end = min(i + chunk_size, total_rows)
599
+ chunk_args = {
600
+ k: self._slice_argument(v, i, end)
601
+ for k, v in arguments.items()
602
+ }
603
+ chunk_args_gpu = {
604
+ k: self._move_tensors(v, self.math_device)
605
+ for k, v in chunk_args.items()
606
+ }
607
+ chunk_res = task.execute(**chunk_args_gpu)
608
+ results.append(self._move_tensors(chunk_res, self.storage_device))
609
+ del chunk_args, chunk_args_gpu, chunk_res
610
+
611
+ if ENABLE_AGGRESSIVE_CLEANUP:
612
+ torch.cuda.empty_cache()
613
+
614
+ if isinstance(results[0], torch.Tensor):
615
+ return torch.cat(results, dim=0)
616
+ elif isinstance(results[0], dict):
617
+ out = {}
618
+ for k in results[0].keys():
619
+ out[k] = torch.cat([r[k] for r in results], dim=0)
620
+ return out
621
+ return results
622
+
623
+ except torch.OutOfMemoryError:
624
+ torch.cuda.empty_cache()
625
+ gc.collect()
626
+ continue
627
+ except Exception as e:
628
+ LOG.warning(f"Fixed chunk {chunk_size} failed: {e}")
629
+ break
630
+
631
+ # Strategy 4: CPU fallback
632
+ LOG.warning(f"All GPU strategies failed for {task_name}, using CPU")
633
+ raise torch.OutOfMemoryError("Forcing CPU fallback")
634
+
635
+ def _run(
636
+ self,
637
+ quiet: bool = False,
638
+ desc: Optional[str] = None,
639
+ ) -> Iterator[Tuple[TaskHandle, Any]]:
640
+ last_use_index = self.schedule.last_use_index
641
+
642
+ values: Dict[TaskHandle, Any] = {}
643
+ if self.cached_values:
644
+ for task, value in self.cached_values.items():
645
+ values[task] = value
646
+
647
+ is_gpu_execution = self.math_device.type != "cpu"
648
+ accelerator = get_torch_accelerator_module(self.math_device.type) if is_gpu_execution else None
649
+
650
+ for idx, task_handle in (
651
+ pbar := tqdm.tqdm(
652
+ list(enumerate(self.schedule.tasks)),
653
+ disable=quiet,
654
+ desc=desc or "Executing graph",
655
+ )
656
+ ):
657
+ task = task_handle.task()
658
+ task_type = type(task).__name__
659
+
660
+ # Log memory stats periodically
661
+ if is_gpu_execution and idx % 10 == 0:
662
+ stats = self._get_memory_stats()
663
+ LOG.debug(
664
+ f"Memory: {stats.get('allocated_gb', 0):.2f}GB allocated, "
665
+ f"{stats.get('free_gb', 0):.2f}GB free of {stats.get('total_gb', 0):.2f}GB"
666
+ )
667
+
668
+ # Determine execution strategy
669
+ is_cpu_only_task = task_type in CPU_ONLY_TASKS
670
+ want_gpu = is_gpu_execution and task.uses_accelerator() and not is_cpu_only_task
671
+
672
+ # Collect arguments
673
+ arguments = {k: values[h] for k, h in task_handle.arguments().items()}
674
+
675
+ success = False
676
+
677
+ # Try GPU execution
678
+ if want_gpu:
679
+ try:
680
+ res = self._execute_with_fallback(task, arguments, accelerator)
681
+ values[task_handle] = res
682
+ success = True
683
+ except torch.OutOfMemoryError:
684
+ LOG.warning(f"All GPU strategies exhausted for {task_type}, falling back to CPU")
685
+ success = False
686
+ except Exception as e:
687
+ LOG.error(f"GPU execution failed for {task_type}: {e}")
688
+ success = False
689
+
690
+ # Cleanup after GPU attempt
691
+ if is_gpu_execution and ENABLE_AGGRESSIVE_CLEANUP:
692
+ gc.collect()
693
+ if accelerator:
694
+ accelerator.empty_cache()
695
+
696
+ # CPU fallback
697
+ if not success:
698
+ if want_gpu:
699
+ LOG.info(f"Executing {task_type} on CPU")
700
+
701
+ # Ensure cleanup before CPU execution
702
+ if is_gpu_execution:
703
+ gc.collect()
704
+ if accelerator:
705
+ accelerator.empty_cache()
706
+
707
+ # Move arguments to CPU
708
+ cpu_arguments = {
709
+ k: self._move_tensors(v, torch.device("cpu"))
710
+ for k, v in arguments.items()
711
+ }
712
+
713
+ res = task.execute(**cpu_arguments)
714
+ del cpu_arguments
715
+ res = self._move_tensors(res, self.storage_device)
716
+ values[task_handle] = res
717
+
718
+ del res
719
+ del arguments
720
+
721
+ if task_handle in self.targets:
722
+ yield (task_handle, values[task_handle])
723
+
724
+ # Evict unreferenced values
725
+ expired = []
726
+ for key in values:
727
+ if idx >= last_use_index[key]:
728
+ expired.append(key)
729
+ for key in expired:
730
+ del values[key]
731
+
732
+ # Periodic cleanup (measure.py strategy)
733
+ self._task_counter += 1
734
+ if is_gpu_execution and ENABLE_AGGRESSIVE_CLEANUP:
735
+ if CLEANUP_FREQUENCY == 0 or self._task_counter % max(1, CLEANUP_FREQUENCY) == 0:
736
+ gc.collect()
737
+ if accelerator:
738
+ accelerator.empty_cache()
739
+
740
+ del values
741
+ del pbar
742
+
743
+ def run(
744
+ self,
745
+ quiet: bool = False,
746
+ desc: Optional[str] = None,
747
+ ) -> Iterator[Tuple[Task, Any]]:
748
+ for handle, value in self._run(quiet=quiet, desc=desc):
749
+ yield (handle.task(), value)
750
+
751
+ def execute(self, desc: Optional[str] = None) -> None:
752
+ for _ in self.run(desc=desc):
753
+ pass
754
+
755
+ def _move_tensors(
756
+ self, value: Any, device: torch.device, non_blocking: Optional[bool] = None
757
+ ) -> Any:
758
+ """Move tensors to specified device, handling nested structures."""
759
+ if non_blocking is None:
760
+ non_blocking = device.type in ["cuda", "xpu"]
761
+
762
+ if isinstance(value, torch.Tensor):
763
+ if value.device == device:
764
+ return value
765
+ return value.to(device=device, non_blocking=non_blocking)
766
+ elif isinstance(value, dict):
767
+ return {
768
+ k: self._move_tensors(v, device, non_blocking)
769
+ for k, v in value.items()
770
+ }
771
+ elif isinstance(value, list):
772
+ return [self._move_tensors(v, device, non_blocking) for v in value]
773
+ elif isinstance(value, tuple):
774
+ return tuple(self._move_tensors(v, device, non_blocking) for v in value)
775
+
776
+ return value