File size: 844 Bytes
f4cade0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from collections.abc import Sequence

import torch.fx as fx


# Public API of this module: only `set_trace` is exported by `from ... import *`.
__all__ = ["set_trace"]


def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """
    Set a breakpoint in `gm`'s generated python code so that running `gm`
    drops into pdb.

    The breakpoint is added by recompiling `gm` with a body transformer that
    prepends ``import pdb; pdb.set_trace()`` to the generated code body. If a
    body transformer is already registered on ``gm.graph``, it is applied
    first and the breakpoint is prepended to its output.

    Args:
        gm: graph module to insert the breakpoint into. It is recompiled in
            place so that the breakpoint takes effect.

    Returns:
        The same `gm` object, with the breakpoint inserted.
    """
    breakpoint_stmt = "import pdb; pdb.set_trace()\n"

    def chain_with_pdb(existing_transform):
        # Factory handed to ``on_generate_code``. It receives the currently
        # registered transformer (which may be None/falsy) and returns the
        # transformer to use for this code generation.
        def transform(body: Sequence[str]) -> list[str]:
            # Honour any pre-existing transformer before adding the breakpoint.
            transformed = existing_transform(body) if existing_transform else body
            return [breakpoint_stmt, *transformed]

        return transform

    # Registering through a ``with`` block scopes the transformer to the
    # recompile below.
    with gm.graph.on_generate_code(make_transformer=chain_with_pdb):
        gm.recompile()

    return gm