| import comfy_aimdo.model_vbar |
| import comfy.model_management |
| import comfy.ops |
|
|
| PREFETCH_QUEUES = [] |
|
|
| def cleanup_prefetched_modules(comfy_modules): |
| for s in comfy_modules: |
| prefetch = getattr(s, "_prefetch", None) |
| if prefetch is None: |
| continue |
| for param_key in ("weight", "bias"): |
| lowvram_fn = getattr(s, param_key + "_lowvram_function", None) |
| if lowvram_fn is not None: |
| lowvram_fn.clear_prepared() |
| if prefetch["signature"] is not None: |
| comfy_aimdo.model_vbar.vbar_unpin(s._v) |
| delattr(s, "_prefetch") |
|
|
| def cleanup_prefetch_queues(): |
| global PREFETCH_QUEUES |
|
|
| for queue in PREFETCH_QUEUES: |
| for entry in queue: |
| if entry is None or not isinstance(entry, tuple): |
| continue |
| _, prefetch_state = entry |
| comfy_modules = prefetch_state[1] |
| if comfy_modules is not None: |
| cleanup_prefetched_modules(comfy_modules) |
| PREFETCH_QUEUES = [] |
|
|
| def prefetch_queue_pop(queue, device, module): |
| if queue is None: |
| return |
|
|
| consumed = queue.pop(0) |
| if consumed is not None: |
| offload_stream, prefetch_state = consumed |
| if offload_stream is not None: |
| offload_stream.wait_stream(comfy.model_management.current_stream(device)) |
| _, comfy_modules = prefetch_state |
| if comfy_modules is not None: |
| cleanup_prefetched_modules(comfy_modules) |
|
|
| prefetch = queue[0] |
| if prefetch is not None: |
| comfy_modules = [] |
| for s in prefetch.modules(): |
| if hasattr(s, "_v"): |
| comfy_modules.append(s) |
|
|
| offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True) |
| comfy.model_management.sync_stream(device, offload_stream) |
| queue[0] = (offload_stream, (prefetch, comfy_modules)) |
|
|
| def make_prefetch_queue(queue, device, transformer_options): |
| if (not transformer_options.get("prefetch_dynamic_vbars", False) |
| or comfy.model_management.NUM_STREAMS == 0 |
| or comfy.model_management.is_device_cpu(device) |
| or not comfy.model_management.device_supports_non_blocking(device)): |
| return None |
|
|
| queue = [None] + queue + [None] |
| PREFETCH_QUEUES.append(queue) |
| return queue |
|
|