File size: 2,321 Bytes
d201410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import comfy_aimdo.model_vbar
import comfy.model_management
import comfy.ops

# Registry of every queue returned by make_prefetch_queue(). cleanup_prefetch_queues()
# walks it to release any leftover per-module prefetch state, then rebinds it to [].
PREFETCH_QUEUES = []

def cleanup_prefetched_modules(comfy_modules):
    """Undo the prefetch bookkeeping attached to each module in ``comfy_modules``.

    Modules without a ``_prefetch`` attribute (or with it set to None) are skipped.
    For every other module this:

    * calls ``clear_prepared()`` on ``weight_lowvram_function`` and
      ``bias_lowvram_function`` when either is present and not None,
    * unpins the module's vbar (``module._v``) via ``vbar_unpin`` when the
      ``_prefetch`` record's ``"signature"`` entry is not None,
    * deletes the ``_prefetch`` attribute.

    NOTE(review): ``_prefetch`` is presumably attached by
    ``comfy.ops.cast_modules_with_vbar`` during prefetch. Confirm there.
    """
    for module in comfy_modules:
        record = getattr(module, "_prefetch", None)
        if record is None:
            continue

        for name in ("weight", "bias"):
            lowvram = getattr(module, f"{name}_lowvram_function", None)
            if lowvram is not None:
                lowvram.clear_prepared()

        # Only records with a signature are unpinned. Presumably the unsigned
        # ones never pinned anything (TODO confirm).
        if record["signature"] is not None:
            comfy_aimdo.model_vbar.vbar_unpin(module._v)

        del module._prefetch

def cleanup_prefetch_queues():
    """Release leftover prefetch state from every registered queue and reset the registry.

    Only tuple entries are considered; these have the shape
    ``(offload_stream, (block, comfy_modules))`` stored by prefetch_queue_pop().
    ``None`` placeholders and blocks that were never prefetched are skipped.
    For each tuple whose ``comfy_modules`` is not None, cleanup_prefetched_modules()
    is run. Afterwards ``PREFETCH_QUEUES`` is rebound to a fresh empty list.

    NOTE(review): unlike prefetch_queue_pop(), no ``wait_stream`` is issued before
    cleanup here. Presumably callers guarantee outstanding copies are done. Confirm.
    """
    global PREFETCH_QUEUES

    prefetched = (
        item
        for registered in PREFETCH_QUEUES
        for item in registered
        if isinstance(item, tuple)
    )
    for _, state in prefetched:
        modules = state[1]
        if modules is not None:
            cleanup_prefetched_modules(modules)

    PREFETCH_QUEUES = []

def prefetch_queue_pop(queue, device, module):
    """Advance a prefetch queue by one step.

    Does nothing when ``queue`` is None (prefetching disabled). Otherwise:

    1. Removes the head entry. If it is a ``(offload_stream, (block, comfy_modules))``
       tuple, makes that offload stream wait on the device's current stream (when the
       stream is not None), then runs cleanup_prefetched_modules() on its modules.
       Presumably the wait keeps later offload work from overtaking compute that
       still uses those weights. TODO confirm.
    2. If the new head is a block (not None), collects its submodules that have a
       ``_v`` attribute, casts them with ``comfy.ops.cast_modules_with_vbar``, calls
       ``comfy.model_management.sync_stream`` with the returned stream, and replaces
       the head with ``(offload_stream, (block, comfy_modules))``.

    ``module`` is accepted for interface compatibility but is not read here.

    NOTE(review): ``sync_stream`` is called even if ``cast_modules_with_vbar``
    returns None, although step 1 guards against a None stream. Confirm
    ``sync_stream`` tolerates None.
    """
    if queue is None:
        return

    finished = queue.pop(0)
    if finished is not None:
        finished_stream, finished_state = finished
        if finished_stream is not None:
            finished_stream.wait_stream(comfy.model_management.current_stream(device))
        _, finished_modules = finished_state
        if finished_modules is not None:
            cleanup_prefetched_modules(finished_modules)

    upcoming = queue[0]
    if upcoming is None:
        return

    vbar_modules = [sub for sub in upcoming.modules() if hasattr(sub, "_v")]
    stream = comfy.ops.cast_modules_with_vbar(vbar_modules, None, device, None, True)
    comfy.model_management.sync_stream(device, stream)
    queue[0] = (stream, (upcoming, vbar_modules))

def make_prefetch_queue(queue, device, transformer_options):
    """Build and register a prefetch queue, or return None when prefetching is off.

    Prefetching is enabled only when all of these hold:

    * ``transformer_options.get("prefetch_dynamic_vbars", False)`` is truthy,
    * ``comfy.model_management.NUM_STREAMS`` is non-zero,
    * ``device`` is not a CPU device,
    * ``device`` supports non-blocking transfers.

    Args:
        queue: Blocks to prefetch, in execution order. Any iterable is accepted
            (list, tuple, generator, ...). It is copied and never mutated.
        device: Device the prefetched weights are cast to.
        transformer_options: Mapping of options. Only ``"prefetch_dynamic_vbars"``
            is read here.

    Returns:
        A new list ``[None, *blocks, None]``, appended to ``PREFETCH_QUEUES``, or
        None when prefetching is disabled. The leading None is what the first
        prefetch_queue_pop() removes (nothing has finished yet). The trailing None
        marks that no block follows the last one.
    """
    if (not transformer_options.get("prefetch_dynamic_vbars", False)
        or comfy.model_management.NUM_STREAMS == 0
        or comfy.model_management.is_device_cpu(device)
        or not comfy.model_management.device_supports_non_blocking(device)):
        return None

    # Unpacking instead of `[None] + queue + [None]`: concatenation only works for
    # real lists and raised TypeError for tuples, generators or module containers.
    padded = [None, *queue, None]
    PREFETCH_QUEUES.append(padded)
    return padded