# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import json
import os
import re
from contextlib import contextmanager, redirect_stdout
from io import StringIO
from typing import Optional

from .utils import logging
from .utils.import_utils import is_torch_available, requires


if is_torch_available():
    import torch
    from safetensors.torch import save_file

    _torch_distributed_available = False
    # Note to code inspectors: this toolbox is intended for people who add models to `transformers`.
    if torch.distributed.is_available():
        import torch.distributed.tensor

        _torch_distributed_available = True
else:
    _torch_distributed_available = False


logger = logging.get_logger(__name__)


def _is_rank_zero():
    """Return True if rank=0 or we aren't running distributed."""
    if not (_torch_distributed_available and torch.distributed.is_initialized()):
        return True
    return torch.distributed.get_rank() == 0


MEMORY_ADDRESS_REGEX = re.compile(r"object at 0x[0-9A-Fa-f]+")


def _sanitize_repr_for_diff(x_str: str) -> str:
    """
    Replace memory addresses in an object's repr with a stable placeholder
    so that beautiful JSON diffs won't be ruined by ephemeral addresses.
    """
    return MEMORY_ADDRESS_REGEX.sub("object at 0xXXXXXXXX", x_str)
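

# A minimal doctest-style sketch of the sanitization above (illustrative comment, not executed at import time):
#     >>> _sanitize_repr_for_diff("<SomeModule object at 0x7f3a2c1b9d30>")
#     '<SomeModule object at 0xXXXXXXXX>'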


def _dtensor_repr(x):
    """Return a stable string representation for a DTensor-like object."""
    if _is_rank_zero():
        return f"DTensor (rank0) -> {repr(x._local_tensor)}"
    return "DTensor(non-rank0)"


def _serialize_tensor_like_io(
    value, debug_path: Optional[str] = None, use_repr: bool = True, path_to_value: Optional[str] = None
):
    """
    Converts Tensors and DTensors to a JSON-serializable dictionary representation.

    Args:
        value (`torch.Tensor`): The tensor to serialize (for a DTensor, its local tensor is passed in).
        debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files.
        use_repr (`bool`, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensor as
            the `value` property in the associated FULL_TENSORS.json file, or to store the full tensor in a separate
            SafeTensors file and record the relative path to that file in the `value` property of the dictionary.
        path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full
            tensor value if `use_repr=False`.

    Returns:
        `dict`: A JSON-serializable summary with `shape`, `dtype`, and `value` entries, plus `mean`/`std`/`min`/`max`
        for float16, float32, and bfloat16 tensors.
    """
    torch.set_printoptions(sci_mode=True)

    if use_repr:
        value_out = _repr_to_list(value)
    elif path_to_value:
        if not path_to_value.endswith(".safetensors"):
            path_to_value += ".safetensors"

        filepath = os.path.join(debug_path, path_to_value) if debug_path else path_to_value
        save_file({"data": value.contiguous().detach().cpu()}, filepath)
        value_out = f"./{path_to_value}"
    else:
        raise ValueError(f"{use_repr=} and {path_to_value=} cannot both be falsy.")

    out = {
        "shape": repr(value.shape),
        "dtype": repr(value.dtype),
        "value": value_out,
    }
    if value.dtype in {torch.float16, torch.float32, torch.bfloat16}:
        out.update(
            {
                "mean": _sanitize_repr_for_diff(repr(value.mean())),
                "std": _sanitize_repr_for_diff(repr(value.std())),
                "min": _sanitize_repr_for_diff(repr(value.min())),
                "max": _sanitize_repr_for_diff(repr(value.max())),
            }
        )
    return out
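

# Illustrative sketch of the summary produced above for a small float tensor with `use_repr=True`
# (exact string formatting depends on torch print options and version):
#     _serialize_tensor_like_io(torch.ones(2, 2))
#     -> {"shape": "torch.Size([2, 2])", "dtype": "torch.float32",
#         "value": [...sanitized repr lines...],
#         "mean": "tensor(1.)", "std": "tensor(0.)", "min": "tensor(1.)", "max": "tensor(1.)"}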


def _serialize_io(value, debug_path: Optional[str] = None, use_repr: bool = True, path_to_value: Optional[str] = None):
    """
    Recursively build a JSON-serializable Python structure from `value`.
    Tensors and DTensors become either sanitized repr strings, or are saved to disk as SafeTensors files and their
    relative paths are recorded in the returned Python structure.
    Lists/tuples/dicts are recursed into.
    All memory addresses are replaced with a stable placeholder.

    Args:
        value: Any Python object, often including torch Tensors, lists, dicts, etc.
        debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files.
        use_repr (`bool`, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as
            the `value` property in the associated FULL_TENSORS.json file, or to store full tensors in separate
            SafeTensors files and record the relative path to each file in the `value` property.
        path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full
            tensor value if `use_repr=False`.

    Returns:
        A nested Python structure (list, dict, or sanitized string) that is safe to json.dump.
    """
    if isinstance(value, (list, tuple)):
        return [
            _serialize_io(v, debug_path=debug_path, use_repr=use_repr, path_to_value=f"{path_to_value}_{i}")
            for i, v in enumerate(value)
        ]

    if isinstance(value, dict):
        return {
            k: _serialize_io(v, debug_path=debug_path, use_repr=use_repr, path_to_value=f"{path_to_value}_{k}")
            for k, v in value.items()
        }

    if hasattr(value, "_local_tensor"):
        return _serialize_tensor_like_io(
            value._local_tensor, debug_path=debug_path, use_repr=use_repr, path_to_value=path_to_value
        )

    if isinstance(value, torch.Tensor):
        return _serialize_tensor_like_io(value, debug_path=debug_path, use_repr=use_repr, path_to_value=path_to_value)

    return _sanitize_repr_for_diff(repr(value))
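

# Illustrative sketch of the recursion above: containers are recursed into, tensors become summary
# dicts, and anything else becomes a sanitized repr string:
#     _serialize_io({"x": torch.zeros(2), "n": 3}, path_to_value="node")
#     -> {"x": {"shape": "torch.Size([2])", "dtype": "torch.float32", "value": [...], ...}, "n": "3"}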


def _repr_to_list(value: torch.Tensor):
    """
    Converts a tensor into a sanitized multi-line string representation.

    Args:
        value (`torch.Tensor`): The tensor to represent.

    Returns:
        `list[str]`: List of string lines representing the tensor.
    """
    torch.set_printoptions(sci_mode=True, linewidth=120)
    with StringIO() as buf, redirect_stdout(buf):
        print(value)  # printed to the redirected stdout so the full repr is captured without line splits
        raw = buf.getvalue()
    return _sanitize_repr_for_diff(raw).splitlines()
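

# Illustrative sketch (sci_mode=True forces scientific notation; exact output depends on the torch version):
#     _repr_to_list(torch.tensor([1.0, 2.0]))
#     -> ["tensor([1.0000e+00, 2.0000e+00])"]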


def prune_outputs_if_children(node):
    """
    Removes a node's "outputs" when it has children, so that outputs only appear at the leaf level.

    Args:
        node (`dict`): A node from the call tree, pruned in place.
    """
    if node.get("children"):
        node.pop("outputs", None)
        for child in node["children"]:
            prune_outputs_if_children(child)
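

# Illustrative sketch: parents lose their "outputs" entry in place, leaves keep theirs:
#     tree = {"outputs": {...}, "children": [{"outputs": {...}}]}  # hypothetical nodes
#     prune_outputs_if_children(tree)  # tree no longer has "outputs"; the leaf child keeps its own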


LAYER_SUFFIX_RE = re.compile(r"(.*)\.(\d+)$")  # matches module paths ending in a numeric index, e.g. "layers.3"


def is_layer_block(node):
    """
    Checks whether a node represents a layer block with submodules.

    Args:
        node (`dict`): A node from the call tree.

    Returns:
        `bool`: Whether the node is a layer block.
    """
    match = LAYER_SUFFIX_RE.match(node.get("module_path", ""))
    if not match or not node.get("children"):
        return False
    number = match.group(2)
    return any(f".{number}." in child.get("module_path", "") for child in node["children"])
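

# Illustrative sketch: a numbered module path whose children share the same index is a layer block:
#     is_layer_block({"module_path": "Model.layers.3",
#                     "children": [{"module_path": "Model.layers.3.self_attn"}]})   -> True
#     is_layer_block({"module_path": "Model.embed_tokens", "children": []})         -> False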


def prune_intermediate_layers(node):
    """
    Recursively removes intermediate layers from the tree to improve readability.
    Keeps at least the first and last layers if many consecutive layers are present.

    Args:
        node (`dict`): The root or subnode to prune recursively.
    """
    if not node.get("children"):
        return
    layer_blocks = [(i, child) for i, child in enumerate(node["children"]) if is_layer_block(child)]

    if len(layer_blocks) > 2:
        to_remove = [i for i, _ in layer_blocks[1:-1]]
        node["children"] = [child for i, child in enumerate(node["children"]) if i not in to_remove]

    for child in node["children"]:
        prune_intermediate_layers(child)
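

# Illustrative sketch: with four sibling layer blocks, only the first and the last survive:
#     node = {"children": [layer_0, layer_1, layer_2, layer_3]}  # hypothetical layer-block nodes
#     prune_intermediate_layers(node)  # node["children"] is now [layer_0, layer_3]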


def log_model_debug_trace(debug_path: Optional[str], model):
    """
    Writes the recorded call tree to disk as two JSON files: a full trace (`*_FULL_TENSORS.json`) and a
    values-stripped summary (`*_SUMMARY.json`).

    Args:
        debug_path (`str`, *optional*): Directory to write the trace files to; defaults to the current directory.
        model: Model carrying the `_call_tree` and `_debugger_module_dump_name` attributes set by the debugger.
    """
    if debug_path:
        try:
            os.makedirs(debug_path, exist_ok=True)
            base = os.path.join(debug_path, model._debugger_module_dump_name + "_debug_tree")
        except Exception as e:
            raise ValueError(f"Unexpected or existing debug_path={debug_path}.") from e
    else:
        base = model._debugger_module_dump_name + "_debug_tree"

    logger.info(f"Writing model trace at {base}.json")
    full_path = base + "_FULL_TENSORS.json"
    summary_path = base + "_SUMMARY.json"

    prune_outputs_if_children(model._call_tree)

    with open(full_path, "w") as f:
        json.dump(model._call_tree, f, indent=2)

    # summary-only version for readability (traverses the tree a second time; TODO: optimize?)
    def strip_values(node):
        def clean(val):
            if isinstance(val, dict):
                val.pop("value", None)
                for v in val.values():
                    clean(v)
            elif isinstance(val, list):
                for item in val:
                    clean(item)

        clean(node.get("inputs", {}))
        clean(node.get("outputs", {}))

        for child in node.get("children", []):
            strip_values(child)

    tree_copy = json.loads(json.dumps(model._call_tree))  # deep copy
    strip_values(tree_copy)

    with open(summary_path, "w") as f:
        json.dump(tree_copy, f, indent=2)
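

# Illustrative sketch of the files written above (assuming debug_path="debug" and a model class `MyModel`):
#     debug/MyModel_debug_tree_FULL_TENSORS.json  # full call tree, tensor values included
#     debug/MyModel_debug_tree_SUMMARY.json       # same tree with the "value" entries stripped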


def _attach_debugger_logic(
    model,
    debug_path: str = ".",
    do_prune_layers: bool = True,
    use_repr: bool = True,
):
    """
    Attaches a debugging wrapper to every module in the model.

    This records structured inputs and outputs during the forward pass into a call tree.

    Args:
        model (`PreTrainedModel`, `nn.Module`): Model to wrap.
        debug_path (`str`, *optional*, defaults to `"."`): Directory to dump debug JSON files.
        do_prune_layers (`bool`, *optional*, defaults to `True`): Whether to prune intermediate layers.
        use_repr (`bool`, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as
            the `value` property in the associated FULL_TENSORS.json file, or to store full tensors in separate
            SafeTensors files and record the relative path to each file in the `value` property.
    """
    class_name = model.__class__.__name__

    # Prepare data structures on the model object
    model._call_tree = {"module_path": class_name, "inputs": None, "outputs": None, "children": []}
    model._debugger_model_call_stack = []
    model._debugger_module_dump_name = class_name  # used for final JSON filename

    if debug_path:
        try:
            os.makedirs(debug_path, exist_ok=True)
        except Exception as e:
            raise ValueError(f"Unexpected or existing debug_path={debug_path}.") from e

    def wrap_forward(module, full_path):
        orig_forward = module.forward

        @functools.wraps(orig_forward)
        def wrapped_forward(*inps, **kws):
            if _is_rank_zero():
                dict_inputs = {"args": inps, "kwargs": kws}
                dict_inputs = {k: dict_inputs[k] for k in dict_inputs if len(dict_inputs[k]) > 0}
                node = {
                    "module_path": full_path,
                    "inputs": _serialize_io(
                        dict_inputs,
                        debug_path=debug_path,
                        use_repr=use_repr,
                        path_to_value=f"{full_path}_inputs",
                    ),
                    "outputs": None,
                    "children": [],
                }
                model._debugger_model_call_stack.append(node)
            with torch.no_grad():
                out = orig_forward(*inps, **kws)

            if _is_rank_zero():
                # only leaf modules record their outputs; parent outputs can be read off their children
                if sum(1 for _ in module.named_children()) > 0:
                    node["outputs"] = None
                else:
                    node["outputs"] = _serialize_io(
                        out,
                        debug_path=debug_path,
                        use_repr=use_repr,
                        path_to_value=f"{full_path}_outputs",
                    )

                finished = model._debugger_model_call_stack.pop()
                # prune empty vertices here as well (mostly empty children nodes)
                if not finished["children"]:
                    finished.pop("children")

                if model._debugger_model_call_stack:
                    model._debugger_model_call_stack[-1]["children"].append(finished)
            return out

        module.forward = wrapped_forward

    # wrap all submodules
    for name, submodule in model.named_modules():
        if name == "":
            continue
        wrap_forward(submodule, f"{class_name}.{name}")

    # wrap top-level forward
    real_top_forward = model.forward

    @functools.wraps(real_top_forward)
    def top_wrapped_forward(*inps, **kws):
        if _is_rank_zero():
            top_node = {
                "module_path": f"{class_name} (top-level)",
                "inputs": _serialize_io(
                    {"args": inps, "kwargs": kws},
                    debug_path=debug_path,
                    use_repr=use_repr,
                    path_to_value=f"{class_name}_inputs",
                ),
                "outputs": None,
                "children": [],
            }
            model._debugger_model_call_stack.append(top_node)

        out = real_top_forward(*inps, **kws)
        if _is_rank_zero() and model._debugger_model_call_stack:
            top_node["outputs"] = _serialize_io(
                out,
                debug_path=debug_path,
                use_repr=use_repr,
                path_to_value=f"{class_name}_outputs",
            )
            finished = model._debugger_model_call_stack.pop()
            model._call_tree["inputs"] = finished["inputs"]
            model._call_tree["outputs"] = finished["outputs"]
            model._call_tree["children"] = finished["children"]
            # prune empty stuff for visibility
            for k in list(model._call_tree.keys()):
                if not model._call_tree[k]:
                    model._call_tree.pop(k)

            # prune layers that are not 0 or last
            if do_prune_layers:
                prune_intermediate_layers(model._call_tree)
            # Write final JSON trace here
            log_model_debug_trace(debug_path=debug_path, model=model)
        return out

    model.forward = top_wrapped_forward


@requires(backends=("torch",))
@contextmanager
def model_addition_debugger_context(
    model,
    debug_path: Optional[str] = None,
    do_prune_layers: bool = True,
    use_repr: bool = True,
):
    """
    # Model addition debugger - context manager for model adders
    This context manager is a power user tool intended for model adders.

    It tracks all forward calls within a model forward pass and logs a slice of each input and output to a nested
    JSON file.
    If `use_repr=True` (the default), the JSON file will record a `repr()`-ized version of the tensors as a list of
    strings. If `use_repr=False`, the full tensors will be stored in separate SafeTensors files and the JSON file will
    provide a relative path to that file.

    Note that this context manager enforces `torch.no_grad()`.

    ## Usage

    Add the context manager to a model to debug it:

    ```python
    import torch

    from PIL import Image
    from transformers import LlavaProcessor, LlavaForConditionalGeneration, model_addition_debugger_context

    torch.random.manual_seed(673)

    # load pretrained model and processor
    model_id = "llava-hf/llava-1.5-7b-hf"
    processor = LlavaProcessor.from_pretrained(model_id)
    model = LlavaForConditionalGeneration.from_pretrained(model_id)

    # create random image input
    random_image = Image.fromarray(torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy())

    # prompt
    prompt = "<image>Describe this image."

    # process inputs
    inputs = processor(text=prompt, images=random_image, return_tensors="pt")

    # call forward method (not .generate!)
    with model_addition_debugger_context(model, debug_path="Your_debug_path", do_prune_layers=False):
        output = model.forward(**inputs)
    ```
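
    To store the full tensors on disk instead of `repr()` slices, pass `use_repr=False`; a sketch reusing the setup
    above (the tensors are then saved as SafeTensors files referenced from the JSON trace):

    ```python
    with model_addition_debugger_context(
        model, debug_path="Your_debug_path", do_prune_layers=False, use_repr=False
    ):
        output = model.forward(**inputs)
    ```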

    """
    orig_forwards = {m: m.forward for _, m in model.named_modules()}
    orig_forwards[model] = model.forward
    _attach_debugger_logic(model, debug_path, do_prune_layers, use_repr)
    try:
        yield model
    finally:
        for module_instance, forward_method in orig_forwards.items():
            module_instance.forward = forward_method