|
|
|
|
|
|
|
|
""" |
|
|
This module implements variable tracking for PyTorch optimizers during Dynamo tracing. |
|
|
|
|
|
The OptimizerVariable class provides specialized handling for optimizer instances by: |
|
|
- Optimizing the tracing of expensive optimizer initialization |
|
|
- Managing optimizer state and parameter group tracking |
|
|
- Handling tensor sources and guards for optimizer state tensors |
|
|
- Supporting CUDA graph execution through static tensor address management |
|
|
- Providing special handling for parameter gradients and optimizer state tensors |
|
|
|
|
|
Key features include: |
|
|
- Efficient initialization tracing via _init_group optimization |
|
|
- Automatic marking of optimizer state tensors as static for CUDA graphs |
|
|
- Proper source tracking for parameter groups, gradients, and state tensors |
|
|
- Guard installation for optimizer state structure |
|
|
- Support for both CPU and GPU tensor handling |
|
|
- Cleanup of static tensor references via finalizers |
|
|
|
|
|
The module integrates with Dynamo's broader tracing system while providing |
|
|
optimizer-specific optimizations and safety guarantees. |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import weakref |
|
|
from typing import TYPE_CHECKING |
|
|
|
|
|
import torch |
|
|
from torch._logging import getArtifactLogger |
|
|
from torch.utils._pytree import tree_map_only |
|
|
|
|
|
from ..guards import GuardBuilder, install_guard |
|
|
from ..source import ( |
|
|
AttrSource, |
|
|
ConstDictKeySource, |
|
|
DictGetItemSource, |
|
|
GetItemSource, |
|
|
GlobalWeakRefSource, |
|
|
GradSource, |
|
|
) |
|
|
from ..utils import GLOBAL_KEY_PREFIX |
|
|
from .base import VariableTracker |
|
|
from .constant import ConstantVariable |
|
|
from .dicts import ConstDictVariable |
|
|
from .lists import ListVariable |
|
|
from .misc import GetAttrVariable |
|
|
from .user_defined import UserDefinedObjectVariable |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from torch._dynamo.symbolic_convert import InstructionTranslator |
|
|
|
|
|
|
|
|
class ArgMappingException(Exception):
    """Raised when a traced ``_init_group`` argument cannot be turned into a Python value.

    ``OptimizerVariable.call_method`` catches it and falls back to ordinary
    tracing of the method.
    """
|
|
|
|
|
|
|
|
class GuardInstallException(Exception):
    """Signals that guards for the optimizer could not be installed.

    ``OptimizerVariable.call_method`` catches it together with
    ``ArgMappingException`` and falls back to ordinary tracing.
    NOTE(review): no raise site is visible in this module.
    """
|
|
|
|
|
|
|
|
# Artifact logger for the "perf_hints" artifact; used below to report grad
# tensors that will be copied during cudagraphs execution.
perf_hint_log = getArtifactLogger(__name__, "perf_hints")
|
|
|
|
|
|
|
|
def _is_static_for_cudagraphs(x): |
|
|
from torch._inductor.cudagraph_trees import get_manager |
|
|
|
|
|
if x.is_cuda: |
|
|
manager = get_manager(x.device.index, False) |
|
|
is_static_address = torch._dynamo.utils.get_static_address_type(x) is not None |
|
|
if manager: |
|
|
return ( |
|
|
is_static_address |
|
|
or manager.current_node._is_cuda_graph_recorded_tensor(x) |
|
|
) |
|
|
else: |
|
|
return is_static_address |
|
|
else: |
|
|
|
|
|
return True |
|
|
|
|
|
|
|
|
class OptimizerVariable(UserDefinedObjectVariable):
    """VariableTracker for optimizer instances.

    Instead of tracing the optimizer's slow ``_init_group`` method, Dynamo
    runs it eagerly on the real optimizer.  The resulting parameters, grads
    and state tensors are then mapped back to sources, and guards are
    installed on the optimizer structure.
    """

    # Bookkeeping fields holding plain Python objects rather than VariableTrackers.
    _nonvar_fields = {
        "grad_to_source",
        "tensor_to_source",
        "static_tensor_names",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        grad_to_source=None,
        static_tensor_names=None,
        tensor_to_source=None,
        **kwargs,
    ) -> None:
        super().__init__(value, **kwargs)
        # grad tensor -> GradSource of the owning parameter's ``.grad``
        self.grad_to_source = grad_to_source or {}
        # parameter / optimizer-state tensor -> source it is reachable from
        self.tensor_to_source = tensor_to_source or {}
        # module-key names of tensors marked static; popped by the finalizer
        self.static_tensor_names = static_tensor_names or set()

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """This is an optimization to avoid tracing the very slow initialization of the optimizer"""
        if name == "_init_group":
            try:
                self.graph_break_if_pending_mutation(tx)
                self.move_step_if_cpu()
                py_args, py_kwargs = self.get_python_args(*args, **kwargs)
                # Run the real _init_group eagerly on the real optimizer.
                ret_val = self.value._init_group(*py_args, **py_kwargs)
                self.map_sources_and_install_guards(tx)
                self.update_list_args(tx, args, kwargs, py_args, py_kwargs)

                # Register a weak reference to the optimizer as a global keyed
                # by its id (presumably so generated code can reach it without
                # keeping it alive -- TODO confirm).
                mangled_name = f"__optimizer_{id(self.value)}"
                tx.store_global_weakref_by_id(mangled_name, self.value)
                self.create_finalizer(tx)

                # The eager return value is baked into the graph as a constant.
                return ConstantVariable.create(ret_val)
            except (ArgMappingException, GuardInstallException):
                # Deliberate fallback: trace _init_group normally below.
                pass

        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Resolve attribute ``name`` on the optimizer.

        ``_init_group`` and ``step`` return a GetAttrVariable, so a subsequent
        call is dispatched to ``call_method``.  Reading ``param_groups`` marks
        every parameter as static and tries to enable ``capturable``.
        """
        if name in ("_init_group", "step"):
            return GetAttrVariable(self, name, source=AttrSource(self.source, name))

        if name == "param_groups":
            from ..decorators import mark_static_address

            for group in self.value.param_groups:
                for p in group["params"]:
                    mark_static_address(p)

            self._set_capturable(tx)

        return super().var_getattr(tx, name)

    def graph_break_if_pending_mutation(self, tx):
        """Raise ``Unsupported`` if any parameter has a pending traced mutation.

        ``_init_group`` is executed eagerly on the real parameters.  Presumably a
        mutation recorded only in the trace would not be visible to it, so this
        case graph-breaks instead.
        """
        side_effects = tx.output.side_effects
        for group in self.value.param_groups:
            for p in group["params"]:
                variable = side_effects.id_to_variable.get(id(p), None)
                if variable and side_effects.has_pending_mutation(variable):
                    from ..exc import Unsupported

                    raise Unsupported("Pending mutation on parameter")

    def _set_capturable(self, tx):
        """Set ``capturable=True`` where it is safe, on both real and traced groups."""
        from .lazy import LazyVariableTracker

        def safe_to_set_capturable(group):
            # Safe only when the group already has a "capturable" key, every
            # param is on CUDA/XPU, and no param has optimizer state yet.
            all_uninitialized = True
            all_gpu = True

            for p in group.get("params", []):
                all_gpu &= p.is_cuda or p.is_xpu
                all_uninitialized &= p not in self.value.state

            return "capturable" in group and all_uninitialized and all_gpu

        for group in self.value.param_groups:
            if safe_to_set_capturable(group):
                group["capturable"] = True

        source = self.source and AttrSource(self.source, "param_groups")
        param_groups_vt = LazyVariableTracker.realize_all(
            VariableTracker.build(tx, self.value.param_groups, source)
        )
        # NOTE(review): every traced group gets capturable=True, including
        # groups whose real dict was left unchanged above; presumably this is
        # intentional so the traced step takes the capturable path -- confirm.
        for param_group_vt in param_groups_vt.items:
            key = ConstDictVariable._HashableTracker(
                ConstantVariable.create("capturable")
            )
            param_group_vt.items[key] = ConstantVariable.create(True)

    def get_python_args(self, *args, **kwargs):
        """Get python values equivalent to the variable tracker args.

        Supported: constants, empty lists, and param-group dicts sourced from
        ``<optimizer>.param_groups[i]``.  Anything else raises
        ``ArgMappingException``.
        """

        def map_arg(arg):
            if isinstance(arg, ConstantVariable):
                return arg.as_python_constant()
            elif isinstance(arg, ListVariable) and not arg.items:
                return []
            elif (
                isinstance(arg, ConstDictVariable)
                and isinstance(arg.source, GetItemSource)
                and isinstance(arg.source.base, AttrSource)
                and arg.source.base.member == "param_groups"
            ):
                return self.value.param_groups[arg.source.index]

            raise ArgMappingException

        new_args = [map_arg(arg) for arg in args]
        new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}

        return new_args, new_kwargs

    def move_step_if_cpu(self):
        """Move CPU ``step`` state tensors onto their parameter's device.

        Non-tensor ``step`` entries (e.g. plain ints kept by some optimizers)
        are left untouched instead of raising ``AttributeError``.
        """
        for p, state in self.value.state.items():
            if "step" not in state:
                continue
            step = state["step"]
            if isinstance(step, torch.Tensor) and step.is_cpu:
                state["step"] = step.to(p.device)

    def map_sources_and_install_guards(self, tx):
        """Record sources for params/grads/state tensors and install guards.

        Steps:
        1. Rebuild ``grad_to_source`` / ``tensor_to_source``.
        2. Mark every state tensor static.
        3. Realize ``param_groups`` and the top-level ``state`` dict, plus the
           state of one param (with a grad) per group.
        4. Guard absent grads with CONSTANT_MATCH.
        5. Log a perf hint for grads that are not static for cudagraphs.
        """
        from ..decorators import mark_static_address
        from .lazy import LazyVariableTracker

        self.grad_to_source = {}
        self.tensor_to_source = {}

        tree_map_only(torch.Tensor, mark_static_address, self.value.state)

        # Realizing these trackers recursively installs their guards.
        params_groups_source = self.source and AttrSource(self.source, "param_groups")
        param_groups_vt = LazyVariableTracker.realize_all(
            VariableTracker.build(tx, self.value.param_groups, params_groups_source)
        )

        state_source = self.source and AttrSource(self.source, "state")
        state_vt = VariableTracker.build(tx, self.value.state, state_source)
        # Only the top-level state dict is realized here; per-param state is
        # handled selectively below.
        state_vt.realize()
        tx.output.guard_on_key_order.add(state_source)

        # Accumulated across *all* groups so every non-static grad is reported.
        non_static_grads = []
        for group, group_vt in zip(self.value.param_groups, param_groups_vt.items):
            # Realize the state of the first param (with a grad) that has
            # state.  This appears to assume params in a group are
            # initialized alike.
            for param in group["params"]:
                if param.grad is None:
                    continue
                key_index = next(
                    (i for i, k in enumerate(self.value.state) if k is param),
                    None,
                )
                # `is not None`: index 0 is a valid position in the state dict.
                if key_index is not None:
                    LazyVariableTracker.realize_all(
                        VariableTracker.build(
                            tx,
                            self.value.state[param],
                            DictGetItemSource(
                                state_source,
                                ConstDictKeySource(state_source, key_index),
                            ),
                        )
                    )
                    break

            params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params"))
            for p, p_vt in zip(group["params"], params_vt.unpack_var_sequence(tx)):
                param_source = p_vt.source
                self.tensor_to_source[p] = param_source
                grad_source = GradSource(param_source, "grad")

                if p.grad is not None:
                    self.grad_to_source[p.grad] = grad_source
                    if not _is_static_for_cudagraphs(p.grad):
                        non_static_grads.append(grad_source)
                else:
                    # Guard that the grad is still absent.
                    install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH))

        # Only emit the hint when the perf_hints artifact logger is enabled.
        if non_static_grads and perf_hint_log.isEnabledFor(logging.DEBUG):
            perf_hint_log.warning(
                "Grad tensors %s will be copied during cudagraphs execution. "
                "If using cudagraphs and the grad tensor addresses will be the same across runs,"
                " use torch._dynamo.decorators.mark_static_address to elide this copy.",
                [src.name() for src in non_static_grads],
            )

        # Give every state tensor not already mapped a source of the form
        # state[<idx-th key>][<inner_idx-th key>].
        for idx, param_state in enumerate(self.value.state.values()):
            p_state_source = DictGetItemSource(
                state_source, ConstDictKeySource(state_source, idx)
            )
            tx.output.guard_on_key_order.add(p_state_source)
            for inner_idx, v in enumerate(param_state.values()):
                if (
                    isinstance(v, torch.Tensor)
                    and v not in self.grad_to_source
                    and v not in self.tensor_to_source
                ):
                    self.tensor_to_source[v] = DictGetItemSource(
                        p_state_source, ConstDictKeySource(p_state_source, inner_idx)
                    )

    def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):
        """Wrap state tensor in a TensorVariable.

        A known source is reused when available.  Tensors that were never
        mapped get a global weak-ref source.  Param/state tensors and unmapped
        tensors are marked static and their names recorded for the finalizer;
        grads are not.
        """
        from ..decorators import mark_static_address

        if tensor_value in self.tensor_to_source:
            mark_static_address(tensor_value)
            source = self.tensor_to_source[tensor_value]
            self.static_tensor_names.add(tx.output.module_key_name(source.name()))
        elif tensor_value in self.grad_to_source:
            source = self.grad_to_source[tensor_value]
        else:
            mark_static_address(tensor_value)
            global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
            source = GlobalWeakRefSource(global_name)
            self.static_tensor_names.add(tx.output.module_key_name(source.name()))

        return VariableTracker.build(tx, tensor_value, source)

    def update_list_args(
        self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs
    ):
        """Update the args and kwargs to the traced optimizer call.

        Lists filled in by the eager ``_init_group`` are mirrored into their
        ListVariable counterparts; mutation is recorded for each item appended.
        """
        for arg, py_arg in zip(args, py_args):
            if isinstance(arg, ListVariable):
                assert isinstance(py_arg, list), (
                    "py_arg should be a list in optimizer variable"
                )
                for i, val in enumerate(py_arg):
                    tx.output.side_effects.mutation(arg)
                    if isinstance(val, torch.Tensor):
                        arg.items.append(self.wrap_tensor(tx, val))
                    else:
                        source = arg.source and GetItemSource(arg.source, i)
                        arg.items.append(VariableTracker.build(tx, val, source))

    def create_finalizer(self, tx):
        """Register a graph finalizer that drops static tensor references.

        Once the optimizer is garbage collected, the finalizer pops the
        recorded static tensor names from the graph module's buffers and
        parameters.  It also clears the tracing context's flat param lists.
        """
        names_to_delete = self.static_tensor_names
        value = self.value
        tc = tx.output.tracing_context

        def init_finalizer(gm):
            def clear_static_tensor_refs():
                for name in names_to_delete:
                    gm._buffers.pop(name, None)
                    gm._parameters.pop(name, None)
                if tc.params_flat:
                    tc.params_flat.clear()
                if tc.params_flat_unwrap_subclasses:
                    tc.params_flat_unwrap_subclasses.clear()

            weakref.finalize(value, clear_static_tensor_refs)

        tx.output.add_graph_finalizer(init_finalizer)
|
|
|