File size: 4,873 Bytes
e00eceb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import math
import ctypes
import threading
import dataclasses
import torch
from typing import NamedTuple

from comfy.quant_ops import QuantizedTensor


class TensorFileSlice(NamedTuple):
    """Describes where a tensor's raw bytes live inside an open file.

    read_tensor_file_slice_into() looks for an instance of this on a tensor's
    untyped storage (attribute ``_comfy_tensor_file_slice``) and uses it to
    read the bytes directly from the file.
    """
    file_ref: object  # file-like object; used through seek() and readinto(); may be None
    thread_id: int  # compared against threading.get_ident() before any read is attempted
    offset: int  # position handed to file_ref.seek() before reading
    size: int  # number of bytes to read; must equal the tensor's byte size


def read_tensor_file_slice_into(tensor, destination):
    """Fill ``destination`` by reading ``tensor``'s bytes straight from its file.

    ``tensor`` must carry a ``TensorFileSlice`` on its untyped storage
    (attribute ``_comfy_tensor_file_slice``). When every precondition holds,
    ``info.size`` bytes starting at ``info.offset`` are read from the file
    directly into the memory behind ``destination.data_ptr()``.

    For a ``QuantizedTensor`` the quantized payload (``_qdata``) is read the
    same way, then the parameters are copied from ``tensor`` while
    ``destination`` keeps its own ``orig_dtype``.

    Args:
        tensor: source tensor (or QuantizedTensor) backed by a file slice.
        destination: CPU tensor (or QuantizedTensor with the same layout
            class) that receives the bytes. Must be contiguous and at least
            ``info.size`` bytes large.

    Returns:
        True if the destination was completely filled, False if this direct
        path is not applicable or the read failed or ended early. On a False
        return after reading has started, ``destination`` may be partially
        overwritten.

    Raises:
        OSError: only from ``file_obj.seek``; read errors are reported as
            False instead.
    """
    if isinstance(tensor, QuantizedTensor):
        if not isinstance(destination, QuantizedTensor):
            return False
        if tensor._layout_cls != destination._layout_cls:
            return False

        if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
            return False

        # Take over the source parameters but keep the destination's own
        # orig_dtype.
        dst_orig_dtype = destination._params.orig_dtype
        destination._params.copy_from(tensor._params, non_blocking=False)
        destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
        return True

    info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
    if info is None:
        return False

    file_obj = info.file_ref
    # The bytes are written raw at destination.data_ptr(), so both sides must
    # be dense: the source has to cover exactly the recorded slice and the
    # destination has to be contiguous and large enough. A non-contiguous
    # destination would receive a scrambled layout, and a stride-0 expanded
    # one could be written past the end of its storage.
    if (destination.device.type != "cpu"
            or file_obj is None
            or threading.get_ident() != info.thread_id
            or destination.numel() * destination.element_size() < info.size
            or not destination.is_contiguous()
            or tensor.numel() * tensor.element_size() != info.size
            or tensor.storage_offset() != 0
            or not tensor.is_contiguous()):
        return False

    if info.size == 0:
        return True

    # Expose the destination memory as a writable byte buffer for readinto().
    buf_type = ctypes.c_ubyte * info.size
    view = memoryview(buf_type.from_address(destination.data_ptr()))

    try:
        file_obj.seek(info.offset)
        done = 0
        # readinto() may deliver fewer bytes than requested; loop until full.
        while done < info.size:
            try:
                n = file_obj.readinto(view[done:])
            except OSError:
                return False
            # None (non-blocking stream with no data) or 0 (EOF) means the
            # slice cannot be completed.
            if n is None or n <= 0:
                return False
            done += n
        return True
    finally:
        view.release()

class TensorGeometry(NamedTuple):
    """Shape and dtype of a tensor, with no storage behind it.

    Provides ``element_size()`` and ``numel()`` under the same names as
    ``torch.Tensor``, so byte-size arithmetic of the form
    ``numel() * element_size()`` works on a geometry as well.
    """

    shape: tuple  # typically a torch.Size (a tuple subclass)
    dtype: torch.dtype

    def element_size(self):
        """Return the size in bytes of one element of ``self.dtype``.

        Works for every dtype, including ``torch.bool`` and complex dtypes,
        which ``torch.iinfo`` / ``torch.finfo`` dispatch cannot handle.
        """
        itemsize = getattr(self.dtype, "itemsize", None)
        if itemsize is not None:
            return itemsize
        # Older torch builds lack dtype.itemsize; ask a storage-less tensor.
        return torch.empty((), dtype=self.dtype, device="meta").element_size()

    def numel(self):
        """Return the number of elements described by ``self.shape``."""
        return math.prod(self.shape)

def tensors_to_geometries(tensors, dtype=None):
    """Convert tensors into TensorGeometry descriptions.

    ``None`` entries and QuantizedTensor instances are passed through
    unchanged. For every other entry the recorded dtype is chosen in this
    order: the ``dtype`` argument if given, else the entry's ``_model_dtype``
    attribute if present, else the entry's own ``dtype``.

    Returns a new list with one item per input entry, in order.
    """
    def describe(entry):
        if entry is None or isinstance(entry, QuantizedTensor):
            return entry
        chosen = getattr(entry, "_model_dtype", entry.dtype)
        if dtype is not None:
            chosen = dtype
        return TensorGeometry(shape=entry.shape, dtype=chosen)

    return [describe(entry) for entry in tensors]

def vram_aligned_size(tensor):
    """Return the byte size of ``tensor`` rounded up to a multiple of 1024.

    A list is summed element by element, so each element is padded on its
    own. A QuantizedTensor is measured as the list of its inner tensors as
    reported by ``__tensor_flatten__``. ``None`` counts as 0 bytes.
    """
    if isinstance(tensor, list):
        return sum(vram_aligned_size(item) for item in tensor)

    if isinstance(tensor, QuantizedTensor):
        attr_names, _ = tensor.__tensor_flatten__()
        return vram_aligned_size([getattr(tensor, name) for name in attr_names])

    if tensor is None:
        return 0

    alignment = 1024
    nbytes = tensor.numel() * tensor.element_size()
    # Ceiling division: number of whole alignment units needed for nbytes.
    units = -(-nbytes // alignment)
    return units * alignment

def interpret_gathered_like(tensors, gathered):
    """Carve typed views out of a flat byte buffer, one per template tensor.

    ``gathered`` must be a 1-D tensor with single-byte elements. Each entry
    of ``tensors`` acts as a template: a view with the template's dtype and
    shape is taken at the current offset, and the offset then advances by
    ``vram_aligned_size`` of that template. QuantizedTensor templates are
    handled per inner tensor and reassembled with ``__tensor_unflatten__``.
    ``None`` templates yield ``None`` and consume no space.

    Returns a list of views (or ``None``) aligned with ``tensors``.

    Raises:
        ValueError: if ``gathered`` is not 1-D single-byte, or if it is too
            small to hold a template at its offset.
    """
    if gathered.dim() != 1 or gathered.element_size() != 1:
        raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")

    cursor = 0
    views = []

    for template_tensor in tensors:
        if template_tensor is None:
            views.append(None)
            continue

        quantized = isinstance(template_tensor, QuantizedTensor)
        if quantized:
            attr_names, qt_ctx = template_tensor.__tensor_flatten__()
            parts = {name: getattr(template_tensor, name) for name in attr_names}
        else:
            parts = {"data": template_tensor}

        carved = {}
        for name, part in parts.items():
            nbytes = part.numel() * part.element_size()
            end = cursor + nbytes
            if end > gathered.numel():
                raise ValueError(f"Buffer too small: needs {end} bytes, but only has {gathered.numel()}. ")
            # Reinterpret the raw bytes with the template's dtype, then shape.
            carved[name] = gathered[cursor:end].view(dtype=part.dtype).view(part.shape)
            # Advance by the padded size so the next view starts aligned.
            cursor += vram_aligned_size(part)

        if quantized:
            views.append(QuantizedTensor.__tensor_unflatten__(carved, qt_ctx, 0, 0))
        else:
            views.append(carved["data"])

    return views

# Module-level flag; nothing in the visible part of this file reads it.
# NOTE(review): presumably toggled from another module — confirm what "aimdo" refers to.
aimdo_enabled = False

# Optional hook called by extra_ram_release(); installed through
# set_ram_cache_release_state(). None means no hook is registered.
extra_ram_release_callback = None
# Headroom value stored by set_ram_cache_release_state(), never negative.
# NOTE(review): units are not visible here (presumably bytes) — confirm.
RAM_CACHE_HEADROOM = 0

def set_ram_cache_release_state(callback, headroom):
    """Install the RAM release hook and record the cache headroom.

    Args:
        callback: callable later invoked by extra_ram_release() with its
            ``target`` argument, or None to clear the hook.
        headroom: value convertible with ``int()``; negative results are
            stored as 0.
    """
    global extra_ram_release_callback, RAM_CACHE_HEADROOM

    # The callback is stored first, so it stays installed even if the
    # int() conversion of headroom raises.
    extra_ram_release_callback = callback
    requested = int(headroom)
    RAM_CACHE_HEADROOM = requested if requested > 0 else 0

def extra_ram_release(target):
    """Forward ``target`` to the registered RAM release hook.

    Returns whatever the hook returns, or 0 when no hook is registered
    (``extra_ram_release_callback`` is None).
    """
    hook = extra_ram_release_callback
    if hook is not None:
        return hook(target)
    return 0