| from collections.abc import Sequence | |
| import torch.fx as fx | |
# Public API of this module: only `set_trace` is exported by `from ... import *`.
__all__ = ["set_trace"]
def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """
    Recompile `gm` so that its generated code opens with a pdb breakpoint.

    Once this returns, calling `gm` stops in pdb before the body of the
    generated function runs.

    Args:
        gm: graph module to instrument. It is recompiled in place so the
            breakpoint takes effect.

    Returns:
        The same `gm` object, now carrying the breakpoint in its code.
    """
    breakpoint_line = "import pdb; pdb.set_trace()\n"

    def make_pdb_transformer(previous):
        # `previous` is whatever body transformer was registered before
        # (possibly None). The transformer built here chains onto it.
        def transform(body: Sequence[str]) -> list[str]:
            # Apply the pre-existing transformer first, so the breakpoint
            # ends up ahead of everything it produced.
            if previous:
                lines = previous(body)
            else:
                lines = body
            return [breakpoint_line, *lines]

        return transform

    # `on_generate_code` installs the transformer only for the duration of
    # this `with` block, so it affects just this one recompile.
    with gm.graph.on_generate_code(make_transformer=make_pdb_transformer):
        gm.recompile()
    return gm