| | import contextlib |
| | import functools |
| | from dataclasses import dataclass, field |
| | from typing import Any, Callable, Dict, Generator, List, Optional, Tuple |
| |
|
| | import torch |
| |
|
| |
|
| | __all__ = ["TracingConfig"] |
| |
|
| |
|
| | @dataclass |
| | class TracingConfig: |
| | """ |
| | Configurations used in ``ParamExecOrderWrapPolicy`` for symbolic tracing of |
| | a model. |
| | |
| | Args: |
| | tracer (torch.fx.Tracer): An instance of ``torch.fx.Tracer`` that will |
| | be used to perform symbolic tracing. ``tracer`` is default to be |
| | ``torch.fx.Tracer()``, but can also be instance of some child class |
| | of ``torch.fx.Tracer``. For example, one may want to use |
| | ``HFTracer`` for models in Transformers: .. _Transformers: |
| | https://huggingface.co/docs/transformers/index |
| | concrete_args (Optional[Dict[str, Any]]): Concrete arguments that should |
| | not be treated as ``torch.fx.Proxy`` when tracing the forward |
| | function. ``concrete_args`` allows one to partially specialize the |
| | forward function, including removing control flow or data |
| | structures. ``concrete_args`` is also the argument used in |
| | :meth:`~torch.fx.Tracer.trace`. |
| | """ |
| |
|
| | tracer: torch.fx.Tracer = torch.fx.Tracer() |
| | concrete_args: Optional[Dict[str, Any]] = None |
| |
|
| |
|
| | @dataclass |
| | class _ExecutionInfo: |
| | """ |
| | Contains the execution order information in the model forward pass. |
| | |
| | Attributes: |
| | current_module: record the module that is currently being traced. |
| | |
| | module_forward_order: a list of modules, where the ordering is based on |
| | when their forward function is called. ``module_forward_order`` |
| | includes the info of how many times a module is called + used to |
| | check the forward order in different iterations. |
| | |
| | param_exec_order: a list of parameters ordered based on their execution |
| | order. |
| | |
| | module_to_execution_infos: a dict that maps each module to a list of |
| | tuples each containing a module and a list of named parameters. |
| | ``module_execution_info_dict`` is used as the parameter execution |
| | order info. For a given module, each tuple: 1. either contains this |
| | module and part of its ``named_parameters`` that will be executed |
| | together, 2. or contains one of its child modules and all of the |
| | child module's ``named_parameters``. The list of tuples is ordered |
| | based on the parameter execution order. |
| | """ |
| |
|
| | current_module: torch.nn.Module |
| | module_forward_order: List[torch.nn.Module] |
| | module_to_execution_infos: Dict[ |
| | torch.nn.Module, |
| | List[Tuple[torch.nn.Module, List[Tuple[str, torch.nn.Parameter]]]], |
| | ] |
| | param_exec_order: List[torch.nn.Parameter] = field(default_factory=list) |
| |
|
| |
|
| | def _init_execution_info(root_module: torch.nn.Module) -> _ExecutionInfo: |
| | """ |
| | Create an instance of _ExecutionInfo with initialization based on |
| | ``root_module``. |
| | |
| | Args: |
| | root_module (torch.nn.Module): the module to get the execution |
| | information via ``tracer.trace()`` inside ``_patch_tracer``. |
| | """ |
| | return _ExecutionInfo( |
| | current_module=root_module, |
| | module_forward_order=[root_module], |
| | module_to_execution_infos={root_module: []}, |
| | ) |
| |
|
| |
|
| | def _patched_create_proxy( |
| | create_proxy: Callable, |
| | execution_info: _ExecutionInfo, |
| | prefixed_param_name_to_param: Dict[str, torch.nn.Parameter], |
| | kind: str, |
| | target: torch.fx.node.Target, |
| | args: Tuple[Any, ...], |
| | kwargs: Dict[str, Any], |
| | name: Optional[str] = None, |
| | type_expr: Optional[Any] = None, |
| | proxy_factory_fn: Callable[[torch.fx.Node], torch.fx.Proxy] = None, |
| | ) -> torch.fx.Proxy: |
| | """ |
| | Override of :meth:`~torch.fx.Tracer.create_proxy`. ``Tracer.create_proxy`` |
| | is called in symbolic tracing for each leaf function/method/module. This |
| | override intercepts the recording of each of these operations to update |
| | ``execution_info.module_to_execution_infos``. |
| | |
| | Args: |
| | create_proxy (Callable): |
| | The ``create_proxy`` function to be patched. |
| | execution_info (_ExecutionInfo): |
| | Used to record the execution information. |
| | prefixed_param_name_to_param (Dict[str, torch.nn.Parameter]): |
| | A dict that maps each prefixed parameter name to the parameter. |
| | kind (str): |
| | The type of the target method. One of 'call_function', |
| | 'call_method', 'get_attr', 'call_module', 'placeholder', or |
| | 'output'. The semantics of these opcodes are described in the |
| | ``torch.fx.Graph`` docstring. This is the input to ``create_proxy``. |
| | target (torch.fx.node.Target): |
| | Contains the string name of the method. This is the input to |
| | ``create_proxy``. |
| | args (Tuple[Any, ...]): |
| | Arguments of the method. This is the input to ``create_proxy``. |
| | kwargs (Dict[str, Any]): |
| | Keyword arguments of the method. This is the input to |
| | ``create_proxy``. |
| | name (Optional[str]): |
| | An optional string name for the ``Node`` created in |
| | ``create_proxy``. This is the input to ``create_proxy``. |
| | type_expr (Optional[Any]): |
| | An optional type annotation representing the Python type the output |
| | of a node will have. This is the input to ``create_proxy``. |
| | proxy_factory_fn (Callable[[torch.fx.Node], torch.fx.Proxy]): |
| | An alternative proxy constructor used in ``create_proxy``. This is |
| | the input to ``create_proxy``. |
| | """ |
| | proxy = create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) |
| |
|
| | module = execution_info.current_module |
| | if kind in ["call_function", "call_method"]: |
| | if args is not None: |
| | named_params: List[Tuple[str, torch.nn.Parameter]] = [] |
| | for arg in args: |
| | if isinstance(arg, torch.fx.Proxy) and arg.node.target in prefixed_param_name_to_param: |
| | param = prefixed_param_name_to_param[arg.node.target] |
| | named_params.append((arg.node.target, param)) |
| | if param not in set(execution_info.param_exec_order): |
| | execution_info.param_exec_order.append(param) |
| | if named_params: |
| | execution_info.module_to_execution_infos[module].append((module, named_params)) |
| | elif kind == "call_module": |
| | named_params = list(module.named_parameters()) |
| | if named_params: |
| | execution_info.module_to_execution_infos[module].append( |
| | (module, named_params) |
| | ) |
| | for (_, p) in named_params: |
| | if p not in set(execution_info.param_exec_order): |
| | execution_info.param_exec_order.append(p) |
| | return proxy |
| |
|
| |
|
| | def _patched_call_module( |
| | call_module: Callable, |
| | execution_info: _ExecutionInfo, |
| | module: torch.nn.Module, |
| | forward: Callable[..., Any], |
| | args: Tuple[Any, ...], |
| | kwargs: Dict[str, Any], |
| | ) -> Any: |
| | """ |
| | Override of :meth:`~torch.fx.Tracer.call_module`. ``Tracer.call_module`` is |
| | called in symbolic tracing for each non-root module. This override |
| | intercepts the recording of each operation to update |
| | ``execution_info.module_forward_order`` and |
| | ``execution_info.module_to_execution_infos``. |
| | |
| | Args: |
| | call_module (Callable): |
| | The ``call_module`` function to be patched. |
| | execution_info (_ExecutionInfo): |
| | Used to repord the execution information. |
| | module (torch.nn.Module): |
| | The module for which a call is being emitted. |
| | forward (Callable[..., Any]): |
| | The ``forward()`` method of the ``torch.nn.Module`` to be invoked. |
| | args (Tuple[Any, ...]): |
| | ``args`` of the module callsite. |
| | kwargs (Dict[str, Any]): |
| | ``kwargs`` of the module callsite. |
| | """ |
| | execution_info.module_forward_order.append(module) |
| | named_params = list(module.named_parameters()) |
| | if named_params: |
| | execution_info.module_to_execution_infos[execution_info.current_module].append( |
| | (module, list(module.named_parameters())) |
| | ) |
| | |
| | prev_current_module = execution_info.current_module |
| | execution_info.current_module = module |
| | |
| | |
| | execution_info.module_to_execution_infos[module] = [] |
| | output = call_module(module, forward, args, kwargs) |
| | execution_info.current_module = prev_current_module |
| | return output |
| |
|
| |
|
| | @contextlib.contextmanager |
| | def _patch_tracer( |
| | tracer: torch.fx.Tracer, |
| | root_module: torch.nn.Module, |
| | execution_info: _ExecutionInfo, |
| | ) -> Generator: |
| | """ |
| | Within the context manager, patches the input tracer so that during |
| | ``tracer.trace()``, the forward order of all modules and the parameter |
| | execution information are recorded. The patches of the input tracer will be |
| | removed after the context manager exits. |
| | |
| | Args: |
| | tracer (torch.fx.Tracer): the input ``tracer`` whose member functions |
| | will be patched within the context manager. |
| | root_module (torch.nn.Module): the top-level module to be traced |
| | and should not contain any FSDP modules. |
| | execution_info (_ExecutionInfo): used to record the execution order |
| | information when performing ``tracer.trace()`` within the context |
| | manager. |
| | """ |
| | original_call_module = tracer.call_module |
| | original_create_proxy = tracer.create_proxy |
| |
|
| | tracer.call_module = functools.partial( |
| | _patched_call_module, original_call_module, execution_info |
| | ) |
| | prefixed_param_name_to_param = dict(root_module.named_parameters()) |
| | tracer.create_proxy = functools.partial( |
| | _patched_create_proxy, original_create_proxy, execution_info, prefixed_param_name_to_param |
| | ) |
| | try: |
| | yield |
| | finally: |
| | tracer.call_module = original_call_module |
| | tracer.create_proxy = original_create_proxy |
| |
|