|
|
import torch.fx as fx |
|
|
|
|
|
def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """Insert a ``pdb`` breakpoint into ``gm``'s generated python code.

    While ``gm`` is recompiled, a body transformer is installed on its graph
    that prepends ``import pdb; pdb.set_trace()`` to the generated code body.
    Running ``gm`` afterwards therefore drops into pdb.

    Args:
        gm: graph module to instrument. It is recompiled so the breakpoint
            takes effect.

    Returns:
        The same ``gm`` object, now with the breakpoint inserted.
    """
    breakpoint_line = "import pdb; pdb.set_trace()\n"

    def wrap_transformer(previous):
        # Chain onto whatever transformer the graph already had (if any), so
        # its rewrite of the body is still applied before the breakpoint line
        # is prepended.
        def transform(body):
            lines = previous(body) if previous else body
            return [breakpoint_line, *lines]

        return transform

    with gm.graph.on_generate_code(make_transformer=wrap_transformer):
        # Regenerate the code while the transformer is installed so the
        # breakpoint ends up in gm's compiled code.
        gm.recompile()

    return gm
|
|
|