File size: 42,670 Bytes
59f1501 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 |
# mypy: ignore-errors
"""
Dictionary-related variable tracking classes for PyTorch Dynamo.
This module implements variable tracking for different types of dictionary-like objects:
- Regular Python dictionaries (dict)
- Ordered dictionaries (collections.OrderedDict)
- Default dictionaries (collections.defaultdict)
- Dictionary views (keys and values)
- Sets and frozensets (implemented internally using dictionaries)
These classes are responsible for tracking dictionary operations during graph compilation,
maintaining proper guards for dictionary mutations and key existence checks. They handle
dictionary creation, modification, key/value access, and view operations while ensuring
correct behavior in the compiled code through appropriate guard installation.
The implementation uses a special _HashableTracker wrapper to handle dictionary keys
while preserving proper aliasing semantics. Sets are implemented as dictionaries with
None values for efficiency and code reuse.
"""
import collections
import functools
import inspect
import operator
import types
from collections.abc import Hashable as py_Hashable
from typing import Optional, TYPE_CHECKING
from torch._subclasses.fake_tensor import is_fake
from .. import graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import raise_observed_exception, unimplemented_v2
from ..guards import GuardBuilder, install_guard
from ..source import is_from_local_source
from ..utils import (
cmp_name_to_op_mapping,
dict_items,
dict_keys,
dict_values,
istype,
specialize_symnode,
)
from .base import ValueMutationNew, VariableTracker
from .constant import ConstantVariable
if TYPE_CHECKING:
from torch._dynamo.codegen import PyCodegen
from torch._dynamo.symbolic_convert import InstructionTranslator
# [Adding a new supported class within the keys of ConstDictVariable]
# - Add its tracker type to is_hashable
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
def raise_args_mismatch(tx, name):
    """Raise an observed ``TypeError`` for a call with the wrong number of arguments.

    Args:
        tx: the active instruction translator the exception is raised into.
        name: the method/function name used in the error message.
    """
    message = f"wrong number of arguments for {name}() call"
    raise_observed_exception(TypeError, tx, args=[ConstantVariable(message)])
def was_instancecheck_override(obj):
    """Return the ``__instancecheck__`` defined directly on ``type(obj)``, else False.

    Only the own namespace of ``type(obj)`` is consulted (bases are not
    searched). Callers use the result as a truthiness flag.
    """
    own_namespace = vars(type(obj))
    return own_namespace.get("__instancecheck__", False)
def raise_unhashable(arg, tx=None):
    """Raise an observed ``TypeError`` reporting that ``arg`` cannot be hashed.

    Args:
        arg: the offending value (its ``type`` is embedded in the message).
        tx: translator to raise into; defaults to the currently active one.
    """
    if tx is None:
        # Imported lazily to avoid an import cycle with symbolic_convert.
        from torch._dynamo.symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
    message = f"unhashable type: {type(arg)}"
    raise_observed_exception(TypeError, tx, args=[ConstantVariable(message)])
def is_hashable(x):
    """Return True if the VariableTracker ``x`` may be used as a dict/set key.

    Checks are ordered on purpose: an unrealized LazyVariableTracker is tested
    first so that it is not realized (and guarded) needlessly.
    """
    # NB - performing isinstance check on a LazyVT realizes the VT, accidentally
    # inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at
    # the underlying value without realizing the VT. Consider updating the
    # lazyVT `is_hashable` method if you see unnecessary guarding for a key VT.
    if (
        isinstance(x, variables.LazyVariableTracker)
        and not x.is_realized()
        and x.is_hashable()
    ):
        return True

    if isinstance(x, variables.TensorVariable):
        # Tensors are hashable if they have an example_value (a fake tensor).
        # Most VTs should have one; ideally we could assert they all do.
        example_value = x.as_proxy().node.meta.get("example_value")
        return example_value is not None

    if isinstance(x, variables.TupleVariable):
        # A tuple is hashable iff every element is (short-circuits like all()).
        return all(map(is_hashable, x.items))

    if isinstance(x, variables.UserDefinedObjectVariable):
        value = x.value
        # int subclasses that keep int's hash (e.g. sentinels in `re`).
        behaves_like_int = (
            not was_instancecheck_override(value)
            and inspect.getattr_static(value, "__hash__") is int.__hash__
            and isinstance(value, int)
        )
        if behaves_like_int:
            return isinstance(value, py_Hashable)
        # Otherwise fall through: some subclasses of UserDefinedObjectVariable
        # may appear in the allow-list below.

    always_hashable_trackers = (
        variables.BuiltinVariable,
        variables.SymNodeVariable,
        variables.ConstantVariable,
        variables.EnumVariable,
        variables.UserDefinedClassVariable,
        variables.UserFunctionVariable,
        variables.SkipFunctionVariable,
        variables.misc.NumpyVariable,
        variables.NNModuleVariable,
        variables.UnspecializedNNModuleVariable,
        variables.MethodWrapperVariable,
        variables.TorchInGraphFunctionVariable,
        variables.TypingVariable,
        variables.FunctoolsPartialVariable,
        variables.WeakRefVariable,
    )
    return isinstance(x, always_hashable_trackers)
class ConstDictVariable(VariableTracker):
    """VariableTracker for ``dict`` and dict subclasses (e.g. OrderedDict).

    ``self.items`` maps ``_HashableTracker``-wrapped key VTs to value VTs.
    ``self.original_items`` is a copy of the mapping given at construction and
    is used by reconstruction to skip entries that did not change.
    ``self.user_cls`` is the concrete Python class being modelled.
    """

    _nonvar_fields = {
        "user_cls",
        *VariableTracker._nonvar_fields,
    }

    class _HashableTracker:
        """
        Auxiliary opaque internal class that wraps a VariableTracker and makes it hashable
        This should not be seen or touched by anything outside of ConstDictVariable and its children
        Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
        """

        def __init__(self, vt) -> None:
            # We specialize SymNodes
            vt = specialize_symnode(vt)
            # TODO Temporarily remove to figure out what keys are we breaking on
            # and add proper support for them
            if not is_hashable(vt):
                raise_unhashable(vt)
            self.vt = vt

        @property
        def underlying_value(self):
            """Python-level value of the key, used for hashing and equality."""
            if (
                isinstance(self.vt, variables.LazyVariableTracker)
                and not self.vt.is_realized()
                and self.vt.is_hashable()
            ):
                # Read the original value without realizing (guarding) the VT.
                return self.vt.original_value()
            if isinstance(self.vt, variables.TensorVariable):
                x = self.vt.as_proxy().node.meta["example_value"]
            elif isinstance(self.vt, variables.TupleVariable):
                Hashable = ConstDictVariable._HashableTracker
                x = tuple(Hashable(e).underlying_value for e in self.vt.items)
            elif isinstance(self.vt, variables.NNModuleVariable):
                return self.vt.value
            elif isinstance(self.vt, variables.UnspecializedNNModuleVariable):
                return self.vt.value
            elif isinstance(self.vt, variables.UserFunctionVariable):
                return self.vt.get_function()
            elif isinstance(self.vt, variables.WeakRefVariable):
                # Access the underlying value inside the referent_vt for the key representation
                Hashable = ConstDictVariable._HashableTracker
                return Hashable(self.vt.referent_vt).underlying_value
            elif isinstance(self.vt, variables.UserDefinedObjectVariable):
                # The re module in Python 3.13+ has a dictionary (_cache2) with
                # an object as key (`class _ZeroSentinel(int): ...`):
                # python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual
                return self.vt.value
            else:
                x = self.vt.as_python_constant()
            return x

        def __hash__(self):
            return hash(self.underlying_value)

        @staticmethod
        def _eq_impl(a, b):
            # TODO: Put this in utils and share it between variables/builtin.py and here
            # NOTE(review): the strict type check makes e.g. 1 and 1.0 (or 1 and
            # True) distinct keys, unlike CPython dicts.
            if type(a) != type(b):
                return False
            elif isinstance(a, tuple):
                Hashable = ConstDictVariable._HashableTracker
                return len(a) == len(b) and all(
                    Hashable._eq_impl(u, v) for u, v in zip(a, b)
                )
            elif is_fake(a):
                # Fake tensors compare by identity, like real tensor keys.
                return a is b
            else:
                return a == b

        def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
            Hashable = ConstDictVariable._HashableTracker
            assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), (
                type(other)
            )
            if isinstance(other, Hashable):
                return Hashable._eq_impl(self.underlying_value, other.underlying_value)

            # constant
            return Hashable._eq_impl(self.underlying_value, other)

    def __init__(
        self,
        items: dict[VariableTracker, VariableTracker],
        user_cls=dict,
        **kwargs,
    ) -> None:
        # .clone() pass these arguments in kwargs but they're recreated a few
        # lines below
        if "original_items" in kwargs:
            kwargs.pop("original_items")
        if "should_reconstruct_all" in kwargs:
            kwargs.pop("should_reconstruct_all")

        super().__init__(**kwargs)

        Hashable = ConstDictVariable._HashableTracker

        # Keys will just be HashableTrackers when cloning, in any other case they'll be VariableTrackers
        assert all(
            isinstance(x, (VariableTracker, Hashable))
            and isinstance(v, VariableTracker)
            for x, v in items.items()
        )

        def make_hashable(key):
            return key if isinstance(key, Hashable) else Hashable(key)

        self.items = {make_hashable(x): v for x, v in items.items()}
        # need to reconstruct everything if the dictionary is an intermediate value
        # or if a pop/delitem was executed
        self.should_reconstruct_all = not is_from_local_source(self.source)
        self.original_items = items.copy()
        self.user_cls = user_cls

    def as_proxy(self):
        return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}

    def debug_repr(self):
        return (
            "{"
            + ", ".join(
                f"{k.vt.debug_repr()}: {v.debug_repr()}" for k, v in self.items.items()
            )
            + "}"
        )

    def as_python_constant(self):
        return {
            k.vt.as_python_constant(): v.as_python_constant()
            for k, v in self.items.items()
        }

    def keys_as_python_constant(self):
        """Return {constant key: value VT}; guards on the full key set."""
        self.install_dict_keys_match_guard()
        return {k.vt.as_python_constant(): v for k, v in self.items.items()}

    def python_type(self):
        return self.user_cls

    def __contains__(self, vt) -> bool:
        """True if ``vt`` is a hashable key present and not marked deleted."""
        assert isinstance(vt, VariableTracker)
        Hashable = ConstDictVariable._HashableTracker
        return (
            is_hashable(vt)
            and Hashable(vt) in self.items
            and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable)
        )

    def len(self):
        """Number of entries, excluding values marked as DeletedVariable."""
        return len(
            [
                x
                for x in self.items.values()
                if not isinstance(x, variables.DeletedVariable)
            ]
        )

    def has_new_items(self):
        """True if any entry differs from ``original_items`` (or all must be rebuilt)."""
        if self.should_reconstruct_all:
            return True
        return any(
            self.is_new_item(self.original_items.get(key.vt), value)
            for key, value in self.items.items()
        )

    def is_new_item(self, value, other):
        """True unless ``other`` is the same VT object as the original ``value``.

        ``value`` may be None when the key was not in ``original_items``.
        """
        # compare the id of the realized values if both values are not lazy VTs
        if value and value.is_realized() and other.is_realized():
            return id(value.realize()) != id(other.realize())
        return id(value) != id(other)

    def reconstruct_kvs_into_new_dict(self, codegen):
        """Emit a BUILD_MAP containing the entries that need reconstruction."""
        # Build a dictionary that contains the keys and values.
        num_args = 0
        for key, value in self.items.items():
            # We can safely call realize() here as it won't introduce any new guards
            item = self.original_items.get(key.vt)
            if self.is_new_item(item, value) or self.should_reconstruct_all:
                codegen(key.vt)
                codegen(value)
                num_args += 1
        codegen.append_output(create_instruction("BUILD_MAP", arg=num_args))

    def reconstruct(self, codegen: "PyCodegen"):
        if self.user_cls is collections.OrderedDict:
            # emit `OrderedDict(constructed_dict)`
            codegen.add_push_null(
                lambda: codegen.extend_output(
                    [
                        codegen.create_load_python_module(collections),
                        codegen.create_load_attr("OrderedDict"),
                    ]
                )
            )
            self.reconstruct_kvs_into_new_dict(codegen)
            codegen.extend_output(create_call_function(1, False))
        else:
            self.reconstruct_kvs_into_new_dict(codegen)

    def getitem_const_raise_exception_if_absent(
        self, tx: "InstructionTranslator", arg: VariableTracker
    ):
        """Return the value for ``arg``; raise an observed KeyError if absent."""
        key = ConstDictVariable._HashableTracker(arg)
        if key not in self.items:
            raise_observed_exception(KeyError, tx)
        return self.items[key]

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Return the value for ``arg``; graph-break if the key is absent."""
        key = ConstDictVariable._HashableTracker(arg)
        if key not in self.items:
            # Not every key VT has a ``.value`` (e.g. tuple or tensor keys), so
            # fall back to its debug representation for the message.
            key_repr = arg.value if hasattr(arg, "value") else arg.debug_repr()
            msg = f"Dictionary key {key_repr} not found during tracing"
            unimplemented_v2(
                gb_type="key not found in dict",
                context=f"Key {key_repr}",
                explanation=msg,
                hints=[
                    "Check if the key exists in the dictionary before accessing it.",
                    *graph_break_hints.USER_ERROR,
                ],
            )
        return self.items[key]

    def maybe_getitem_const(self, arg: VariableTracker):
        """Return the value for ``arg`` or None if the key is absent."""
        key = ConstDictVariable._HashableTracker(arg)
        if key not in self.items:
            return None
        return self.items[key]

    def realize_key_vt(self, arg: VariableTracker):
        # Realize the LazyVT on a particular index
        assert arg in self
        key = ConstDictVariable._HashableTracker(arg)
        index = tuple(self.items.keys()).index(key)
        original_key_vt = tuple(self.original_items.keys())[index]
        if isinstance(original_key_vt, variables.LazyVariableTracker):
            original_key_vt.realize()

    def install_dict_keys_match_guard(self):
        if self.source:
            install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH))

    def install_dict_contains_guard(self, tx, args):
        # Key guarding - These are the cases to consider
        # 1) The dict has been mutated. In this case, we would have already
        # inserted a DICT_KEYS_MATCH guard, so we can skip.
        #
        # 2) args[0].source is None. This happens for const keys. Here, we
        # have to insert the DICT_CONTAINS guard.
        #
        # 3) args[0].source is not None. This can happen for non-const VTs.
        #   3a) contains=True. In this case, we can access the lazyVT from
        #   original_items and selectively realize it.
        #   3b) contains=False. There is no easy way to selectively apply this
        #   DICT_NOT_CONTAINS guard because our guard are represented via trees.
        #   Be conservative and add DICT_KEYS_MATCH guard.
        from . import ConstantVariable

        if not self.source:
            return

        if tx.output.side_effects.is_modified(self):
            return

        contains = args[0] in self
        if args[0].source is None and isinstance(args[0], ConstantVariable):
            install_guard(
                self.make_guard(
                    functools.partial(
                        GuardBuilder.DICT_CONTAINS,
                        key=args[0].value,
                        invert=not contains,
                    )
                )
            )
        elif args[0].source:
            if contains:
                self.realize_key_vt(args[0])
            else:
                self.install_dict_keys_match_guard()

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # NB - Both key and value are LazyVariableTrackers in the beginning. So,
        # we have to insert guards when a dict method is accessed. For this to
        # be simple, we are conservative and overguard. We skip guard only for
        # get/__getitem__ because the key guard will be inserted by the
        # corresponding value VT. For __contains__, we add a DICT_CONTAINS
        # guard. But for all the other methods, we insert the DICT_KEYS_MATCH
        # guard to be conservative.
        from . import BuiltinVariable, ConstantVariable

        Hashable = ConstDictVariable._HashableTracker

        arg_hashable = args and is_hashable(args[0])

        if name == "__init__":
            # dict.__init__ on an existing dict updates it in place.
            temp_dict_vt = variables.BuiltinVariable(dict).call_dict(
                tx, *args, **kwargs
            )
            tx.output.side_effects.mutation(self)
            self.items.update(temp_dict_vt.items)
            return ConstantVariable.create(None)
        elif name == "__getitem__":
            # Key guarding - Nothing to do. LazyVT for value will take care.
            assert len(args) == 1
            return self.getitem_const_raise_exception_if_absent(tx, args[0])
        elif name == "items":
            assert not (args or kwargs)
            self.install_dict_keys_match_guard()
            if self.source:
                tx.output.guard_on_key_order.add(self.source)
            return DictItemsVariable(self)
        elif name == "keys":
            self.install_dict_keys_match_guard()
            if self.source:
                tx.output.guard_on_key_order.add(self.source)
            assert not (args or kwargs)
            return DictKeysVariable(self)
        elif name == "values":
            self.install_dict_keys_match_guard()
            if self.source:
                tx.output.guard_on_key_order.add(self.source)
            assert not (args or kwargs)
            return DictValuesVariable(self)
        elif name == "copy":
            self.install_dict_keys_match_guard()
            assert not (args or kwargs)
            return self.clone(
                items=self.items.copy(), mutation_type=ValueMutationNew(), source=None
            )
        elif name == "__len__":
            assert not (args or kwargs)
            self.install_dict_keys_match_guard()
            return ConstantVariable.create(len(self.items))
        elif name == "__setitem__" and self.is_mutable():
            if not arg_hashable:
                raise_unhashable(args[0])

            self.install_dict_keys_match_guard()
            assert not kwargs and len(args) == 2
            tx.output.side_effects.mutation(self)
            self.items[Hashable(args[0])] = args[1]
            return ConstantVariable.create(None)
        elif name == "__delitem__" and arg_hashable and self.is_mutable():
            self.install_dict_keys_match_guard()
            self.should_reconstruct_all = True
            tx.output.side_effects.mutation(self)
            self.items.__delitem__(Hashable(args[0]))
            return ConstantVariable.create(None)
        elif name in ("pop", "get") and len(args) in (1, 2) and args[0] not in self:
            # Missing key: guard that the key is absent, then return the
            # default (or raise KeyError for a default-less pop).
            self.install_dict_contains_guard(tx, args)
            if len(args) == 1:
                if name == "pop":
                    raise_observed_exception(KeyError, tx)
                return ConstantVariable.create(None)
            else:
                return args[1]
        elif name == "pop" and arg_hashable and self.is_mutable():
            self.should_reconstruct_all = True
            tx.output.side_effects.mutation(self)
            return self.items.pop(Hashable(args[0]))
        elif name == "clear":
            self.should_reconstruct_all = True
            tx.output.side_effects.mutation(self)
            self.items.clear()
            return ConstantVariable.create(None)
        elif name == "update" and self.is_mutable():
            # In general, this call looks like `a.update(b, x=1, y=2, ...)`.
            # Either `b` or the kwargs is omittable, but not both.
            self.install_dict_keys_match_guard()
            has_arg = len(args) == 1
            has_kwargs = len(kwargs) > 0
            if has_arg or has_kwargs:
                tx.output.side_effects.mutation(self)
                if has_arg:
                    if isinstance(args[0], ConstDictVariable):
                        # NB - Guard on all the keys of the other dict to ensure
                        # correctness.
                        args[0].install_dict_keys_match_guard()
                        dict_vt = args[0]
                    else:
                        dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
                    self.items.update(dict_vt.items)
                if has_kwargs:
                    # Handle kwargs
                    kwargs = {
                        Hashable(ConstantVariable.create(k)): v
                        for k, v in kwargs.items()
                    }
                    self.items.update(kwargs)
                return ConstantVariable.create(None)
            else:
                return super().call_method(tx, name, args, kwargs)
        elif name in ("get", "__getattr__") and args[0] in self:
            # Key guarding - Nothing to do.
            return self.getitem_const(tx, args[0])
        elif name == "__contains__" and len(args) == 1:
            if not arg_hashable:
                raise_unhashable(args[0])

            self.install_dict_contains_guard(tx, args)
            contains = args[0] in self
            return ConstantVariable.create(contains)
        elif name == "setdefault" and arg_hashable and self.is_mutable():
            self.install_dict_keys_match_guard()
            assert not kwargs
            assert len(args) <= 2
            value = self.maybe_getitem_const(args[0])
            if value is not None:
                return value
            else:
                if len(args) == 1:
                    x = ConstantVariable.create(None)
                else:
                    x = args[1]
                tx.output.side_effects.mutation(self)
                self.items[Hashable(args[0])] = x
                return x
        elif name == "move_to_end":
            self.install_dict_keys_match_guard()
            assert not kwargs and len(args) == 1
            key = Hashable(args[0])
            if key not in self.items:
                # Mirror OrderedDict.move_to_end: a missing key is a user-visible
                # KeyError, not an internal tracing error.
                raise_observed_exception(KeyError, tx)
            tx.output.side_effects.mutation(self)
            val = self.items[key]
            self.items.pop(key)
            self.items[key] = val
            return ConstantVariable.create(None)
        elif name == "__or__":
            assert len(args) == 1
            if not isinstance(args[0], ConstDictVariable):
                msg = f"unsupported operand type(s) for |: 'dict' and '{args[0].python_type().__name__}'"
                # Surface as a TypeError observable (and catchable) by user code.
                raise_observed_exception(
                    TypeError, tx, args=[ConstantVariable.create(msg)]
                )

            self.install_dict_keys_match_guard()
            new_dict_vt = self.clone(
                items=self.items.copy(), mutation_type=ValueMutationNew(), source=None
            )

            # NB - Guard on all the keys of the other dict to ensure
            # correctness.
            args[0].install_dict_keys_match_guard()
            new_dict_vt.items.update(args[0].items)
            return new_dict_vt
        else:
            return super().call_method(tx, name, args, kwargs)

    def unpack_var_sequence(self, tx):
        """Return the key VTs (iteration order); guards on the full key set."""
        self.install_dict_keys_match_guard()
        return [x.vt for x in self.items.keys()]

    def call_obj_hasattr(self, tx, name):
        # dict not allow setting arbitrary attributes. To check for hasattr, we can just check the __dict__ of the dict.
        # OrderedDict though requires side effects tracking because it supports arbitrary setattr.
        if self.user_cls is dict:
            if name in self.user_cls.__dict__:
                return ConstantVariable.create(True)
            return ConstantVariable.create(False)

        msg = f"hasattr on {self.user_cls} is not supported"
        unimplemented_v2(
            gb_type="unsupported hasattr operation",
            context=f"Class {self.user_cls}",
            explanation=msg,
            hints=[
                "Consider using a regular dictionary instead",
                *graph_break_hints.SUPPORTABLE,
            ],
        )

    def clone(self, **kwargs):
        # Cloning copies the key set, so guard on it first.
        self.install_dict_keys_match_guard()
        return super().clone(**kwargs)
class MappingProxyVariable(VariableTracker):
    """VariableTracker for a ``types.MappingProxyType`` over a tracked dict.

    The proxied dict is held in ``self.dv_dict``. Iteration and every method
    call are delegated to it.
    """

    # proxies to the original dict_vt
    def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
        super().__init__(**kwargs)
        assert isinstance(dv_dict, ConstDictVariable)
        self.dv_dict = dv_dict

    def python_type(self):
        return types.MappingProxyType

    def unpack_var_sequence(self, tx):
        return self.dv_dict.unpack_var_sequence(tx)

    def reconstruct(self, codegen: "PyCodegen"):
        if self.source:
            # A pre-existing proxy cannot be rebuilt without losing its link
            # to the original dict, so refuse (graph break).
            msg = (
                f"Preexisting MappingProxyVariable (source: {self.source}) cannot be reconstructed "
                "because the connection to the original dict will be lost."
            )
            unimplemented_v2(
                gb_type="mapping proxy cannot be reconstructed",
                context=f"Source: {self.source}",
                explanation=msg,
                hints=[
                    "Use a mapping proxy constructed in the same `torch.compile` region.",
                    *graph_break_hints.SUPPORTABLE,
                ],
            )

        def load_mapping_proxy_type():
            # load types.MappingProxyType
            codegen.extend_output(
                [
                    codegen.create_load_python_module(types),
                    codegen.create_load_attr("MappingProxyType"),
                ]
            )

        # emit `types.MappingProxyType(<dict>)`
        codegen.add_push_null(load_mapping_proxy_type)
        codegen(self.dv_dict)
        codegen.extend_output(create_call_function(1, False))

    def call_method(
        self,
        tx,
        name,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        proxy_may_be_stale = (
            self.source and tx.output.side_effects.has_existing_dict_mutation()
        )
        if proxy_may_be_stale:
            msg = (
                "A dict has been modified while we have an existing mappingproxy object. "
                "A mapping proxy object, as the name suggest, proxies a mapping "
                "object (usually a dict). If the original dict object mutates, it "
                "is reflected in the proxy object as well. For an existing proxy "
                "object, we do not know the original dict it points to. Therefore, "
                "for correctness we graph break when there is dict mutation and we "
                "are trying to access a proxy object."
            )
            unimplemented_v2(
                gb_type="mapping proxy affected by dictionary mutation",
                context=f"Source: {self.source}, Dict mutation detected",
                explanation=msg,
                hints=[
                    "Avoid modifying dictionaries that might be referenced by mapping proxy objects",
                    "Or avoid using the mapping proxy objects after modifying its underlying dictionary",
                ],
            )
        return self.dv_dict.call_method(tx, name, args, kwargs)
class NNModuleHooksDictVariable(ConstDictVariable):
    """ConstDictVariable for nn.Module hook dicts.

    Both key-guard hooks are disabled so that no guards are installed on the
    hook ids.
    """

    def install_dict_keys_match_guard(self):
        """Intentionally a no-op: do not guard on hook dict keys."""

    def install_dict_contains_guard(self, tx, args):
        """Intentionally a no-op: do not guard on hook dict membership."""
class DefaultDictVariable(ConstDictVariable):
    """VariableTracker for ``collections.defaultdict``.

    ``self.default_factory`` holds the VariableTracker of the factory callable,
    or None when none was given. The factory is called on ``__getitem__``
    misses, and its result is stored under the missing key.
    """

    def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
        super().__init__(items, user_cls, **kwargs)
        assert user_cls is collections.defaultdict
        self.default_factory = default_factory

    def is_python_constant(self):
        # Return false for unsupported defaults. This ensures that a bad handler
        # path is not taken in BuiltinVariable for getitem.
        # NOTE(review): default_factory is a VariableTracker, so this membership
        # test against builtin types presumably never matches; as written it
        # returns False whenever the dict is empty — confirm intent.
        if self.default_factory not in [list, tuple, dict] and not self.items:
            return False
        return super().is_python_constant()

    def _has_default_factory(self):
        """Return False when there is no factory to call on a missing key."""
        if self.default_factory is None:
            return False
        # Presumably `defaultdict(None)` may be modelled with a constant-None VT.
        return not (
            istype(self.default_factory, ConstantVariable)
            and self.default_factory.value is None
        )

    def debug_repr(self):
        # default_factory may be a plain None (no VT), which has no debug_repr.
        factory_repr = (
            "None"
            if self.default_factory is None
            else self.default_factory.debug_repr()
        )
        return f"defaultdict({factory_repr}, {super().debug_repr()})"

    @staticmethod
    def is_supported_arg(arg):
        """True if ``arg`` is a factory this tracker can model."""
        if isinstance(arg, variables.BuiltinVariable):
            return arg.fn in (list, tuple, dict, set)
        else:
            return isinstance(arg, variables.functions.BaseUserFunctionVariable)

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__getitem__":
            assert len(args) == 1

            if args[0] in self:
                return self.getitem_const(tx, args[0])
            if not self._has_default_factory():
                # Like a plain dict: a user-visible KeyError, not a crash of
                # the tracer itself.
                raise_observed_exception(KeyError, tx, args=[args[0]])
            default_var = self.default_factory.call_function(tx, [], {})
            super().call_method(tx, "__setitem__", (args[0], default_var), kwargs)
            return default_var
        else:
            return super().call_method(tx, name, args, kwargs)

    def reconstruct(self, codegen):
        # emit `defaultdict(default_factory, new_dict)`
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_python_module(collections),
                    codegen.create_load_attr("defaultdict"),
                ]
            )
        )
        if self.default_factory is None:
            # No factory VT to emit; defaultdict(None, ...) has no factory.
            codegen(ConstantVariable.create(None))
        else:
            codegen(self.default_factory)
        self.reconstruct_kvs_into_new_dict(codegen)
        codegen.extend_output(create_call_function(2, False))
# TODO: Implementing this via inheritance rather than composition is a
# footgun, because self method calls in dict will route back to the set
# implementation, which is almost assuredly wrong
class SetVariable(ConstDictVariable):
    """We model sets as dictionaries whose values are all None.

    The set elements are the keys of the inherited ``self.items``. Every value
    is the placeholder returned by ``_default_value()``.
    """

    def __init__(
        self,
        items: list[VariableTracker],
        **kwargs,
    ) -> None:
        items = dict.fromkeys(items, SetVariable._default_value())
        super().__init__(items, **kwargs)

    def debug_repr(self):
        if not self.items:
            return "set()"
        else:
            return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"

    @property
    def set_items(self):
        return set(self.items.keys())

    @staticmethod
    def _default_value():
        # Variable to fill in the values of the underlying dictionary
        return ConstantVariable.create(None)

    def as_proxy(self):
        return {k.vt.as_proxy() for k in self.set_items}

    def python_type(self):
        return set

    def as_python_constant(self):
        return {k.vt.as_python_constant() for k in self.set_items}

    def reconstruct(self, codegen: "PyCodegen"):
        element_vts = [x.vt for x in self.set_items]
        codegen.foreach(element_vts)
        codegen.append_output(create_instruction("BUILD_SET", arg=len(element_vts)))

    def call_method(
        self,
        tx,
        name,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> "VariableTracker":
        # We forward the calls to the dictionary model
        if name == "__init__":
            # set.__init__ replaces the contents. kwargs must be forwarded as
            # keyword arguments, not unpacked as extra positional args.
            temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs)
            tx.output.side_effects.mutation(self)
            self.items.clear()
            self.items.update(temp_set_vt.items)
            return ConstantVariable.create(None)
        elif name == "add":
            assert not kwargs
            if len(args) != 1:
                raise_args_mismatch(tx, name)
            # Falls through to the dict model as `self[elem] = None`.
            name = "__setitem__"
            args = (args[0], SetVariable._default_value())
        elif name == "pop":
            assert not kwargs
            assert not args
            # Choose an item at random and pop it via the Dict.pop method
            try:
                result = self.set_items.pop().vt
            except KeyError as e:
                raise_observed_exception(
                    KeyError, tx, args=list(map(ConstantVariable.create, e.args))
                )
            super().call_method(tx, name, (result,), kwargs)
            return result
        elif name == "isdisjoint":
            assert not kwargs
            assert len(args) == 1
            return variables.UserFunctionVariable(
                polyfills.set_isdisjoint
            ).call_function(tx, [self, args[0]], {})
        elif name == "intersection":
            assert not kwargs
            return variables.UserFunctionVariable(
                polyfills.set_intersection
            ).call_function(tx, [self, *args], {})
        elif name == "intersection_update":
            assert not kwargs
            return variables.UserFunctionVariable(
                polyfills.set_intersection_update
            ).call_function(tx, [self, *args], {})
        elif name == "union":
            assert not kwargs
            return variables.UserFunctionVariable(polyfills.set_union).call_function(
                tx, [self, *args], {}
            )
        elif name == "difference":
            assert not kwargs
            return variables.UserFunctionVariable(
                polyfills.set_difference
            ).call_function(tx, [self, *args], {})
        elif name == "difference_update":
            assert not kwargs
            return variables.UserFunctionVariable(
                polyfills.set_difference_update
            ).call_function(tx, [self, *args], {})
        elif name == "symmetric_difference":
            if len(args) != 1:
                raise_args_mismatch(tx, name)
            assert not kwargs
            return variables.UserFunctionVariable(
                polyfills.set_symmetric_difference
            ).call_function(tx, [self, *args], {})
        elif name == "symmetric_difference_update":
            if len(args) != 1:
                raise_args_mismatch(tx, name)
            assert not kwargs
            return variables.UserFunctionVariable(
                polyfills.set_symmetric_difference_update
            ).call_function(tx, [self, *args], {})
        elif name == "update" and self.is_mutable():
            assert not kwargs
            return variables.UserFunctionVariable(polyfills.set_update).call_function(
                tx, [self, *args], {}
            )
        elif name == "remove":
            assert not kwargs
            assert len(args) == 1
            if args[0] not in self:
                raise_observed_exception(KeyError, tx, args=args)
            return super().call_method(tx, "pop", args, kwargs)
        elif name == "discard":
            assert not kwargs
            assert len(args) == 1
            if args[0] in self:
                return super().call_method(tx, "pop", args, kwargs)
            else:
                return ConstantVariable.create(value=None)
        elif name in ("issubset", "issuperset"):
            # Both take exactly one iterable; guard before indexing args[0].
            if len(args) != 1:
                raise_args_mismatch(tx, name)
            op = {
                "issubset": operator.le,
                "issuperset": operator.ge,
            }
            other = args[0].realize()
            if not istype(other, SetVariable):
                other = variables.BuiltinVariable(set).call_function(tx, [other], {})
            return variables.BuiltinVariable(op.get(name)).call_function(
                tx, [self, other], {}
            )
        return super().call_method(tx, name, args, kwargs)

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        raise RuntimeError("Illegal to getitem on a set")

    def install_dict_keys_match_guard(self):
        # Already EQUALS_MATCH guarded
        pass

    def install_dict_contains_guard(self, tx, args):
        # Already EQUALS_MATCH guarded
        pass
class FrozensetVariable(SetVariable):
    """VariableTracker for ``frozenset``: a SetVariable that rejects mutation."""

    def __init__(
        self,
        items: list[VariableTracker],
        **kwargs,
    ) -> None:
        super().__init__(items, **kwargs)

    def debug_repr(self):
        if not self.items:
            return "frozenset()"
        else:
            # Match Python's repr so a frozenset is distinguishable from a set.
            return (
                "frozenset({"
                + ",".join(k.vt.debug_repr() for k in self.items.keys())
                + "})"
            )

    @property
    def set_items(self):
        return self.items.keys()

    def python_type(self):
        return frozenset

    def as_python_constant(self):
        # Must be a frozenset (not a set) so the constant agrees with
        # python_type() and stays hashable, e.g. when used as a dict key.
        return frozenset(k.vt.as_python_constant() for k in self.set_items)

    def reconstruct(self, codegen: "PyCodegen"):
        # emit `builtins.frozenset({<items>})`. The callable must be pushed
        # before its argument. Loading it through the builtins module avoids
        # relying on "frozenset" already being in co_names.
        import builtins

        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_python_module(builtins),
                    codegen.create_load_attr("frozenset"),
                ]
            )
        )
        element_vts = [x.vt for x in self.set_items]
        codegen.foreach(element_vts)
        codegen.extend_output(
            [
                create_instruction("BUILD_SET", arg=len(element_vts)),
                *create_call_function(1, False),
            ]
        )

    def call_method(
        self,
        tx,
        name,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> "VariableTracker":
        if name in ["add", "pop", "update", "remove", "discard", "clear"]:
            raise RuntimeError(f"Illegal call_method {name} on a frozenset")
        elif name == "__init__":
            # frozenset is immutable. Calling __init__ again shouldn't have any effect
            # In[1]: s = frozenset([1, 2])
            #
            # In[2]: s.__init__([3, 4])
            #
            # In[3]: s
            # frozenset({1, 2})
            return ConstantVariable.create(None)
        return super().call_method(tx, name, args, kwargs)
class DictKeySetVariable(SetVariable):
    """Set-like tracker whose Python type is ``dict_keys``.

    Read-only set operations come from SetVariable; mutating methods are
    rejected because ``dict_keys`` views cannot be modified directly.
    """

    def __init__(
        self,
        items: list[VariableTracker],
        **kwargs,
    ) -> None:
        super().__init__(items, **kwargs)

    def debug_repr(self):
        if not self.items:
            return "dict_keys([])"
        else:
            return (
                "dict_keys(["
                + ",".join(k.vt.debug_repr() for k in self.items.keys())
                + "])"
            )

    @property
    def set_items(self):
        # The backing dict itself; iterating it yields the key trackers.
        return self.items

    def python_type(self):
        return dict_keys

    def as_python_constant(self):
        # Feed dict.fromkeys a generator rather than a set comprehension so the
        # resulting dict_keys follows the backing dict's iteration order
        # instead of an arbitrary set order. fromkeys still drops duplicates.
        return dict.fromkeys(
            (k.vt.as_python_constant() for k in self.set_items), None
        ).keys()

    def call_method(
        self,
        tx,
        name,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> "VariableTracker":
        """Dispatch a method call, refusing every mutating set method.

        Raises:
            RuntimeError: if ``name`` is a method only mutable sets have.
        """
        if name in [
            "add",
            "pop",
            "update",
            "remove",
            "discard",
            "clear",
            # In-place variants would otherwise reach SetVariable's polyfills.
            "difference_update",
            "intersection_update",
            "symmetric_difference_update",
        ]:
            raise RuntimeError(f"Illegal call_method {name} on a dict_keys")
        return super().call_method(tx, name, args, kwargs)
class DictViewVariable(VariableTracker):
    """
    Base tracker for the view objects a dict hands out (_PyDictViewObject).

    Not meant to be instantiated directly: concrete subclasses set ``kv`` to
    the name of the dict method they mirror ("keys", "values" or "items")
    and implement ``view_items_vt``.
    """

    # Name of the dict method this view corresponds to; set by subclasses.
    kv: Optional[str] = None

    def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
        super().__init__(**kwargs)
        assert self.kv in ("keys", "values", "items")
        assert isinstance(dv_dict, ConstDictVariable)
        self.dv_dict = dv_dict

    @property
    def view_items(self):
        # Apply the same view method to the tracked dict's item storage.
        view_method = getattr(self.dv_dict.items, self.kv)
        return view_method()

    @property
    def view_items_vt(self):
        # Subclasses return the view's contents as a list of trackers.
        raise NotImplementedError

    def unpack_var_sequence(self, tx):
        return self.view_items_vt

    def reconstruct(self, codegen: "PyCodegen"):
        # Rebuild as ``<dict>.<kv>()``.
        codegen(self.dv_dict)
        codegen.load_method(self.kv)
        codegen.call_method(0)

    def call_method(
        self,
        tx,
        name,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        if name != "__len__":
            return super().call_method(tx, name, args, kwargs)
        # A view's length is the length of the dict it views.
        return self.dv_dict.call_method(tx, name, args, kwargs)
class DictKeysVariable(DictViewVariable):
    """Tracks a ``dict.keys()`` view of a ConstDictVariable.

    Membership is delegated to the underlying dict; rich comparisons are
    evaluated set-wise against other set-like trackers.
    """

    kv = "keys"

    @property
    def set_items(self):
        # Set of the key trackers stored in the underlying dict.
        return set(self.view_items)

    @property
    def view_items_vt(self):
        # Returns an iterable of the unpacked items
        return [x.vt for x in self.view_items]

    def python_type(self):
        return dict_keys

    def call_method(
        self,
        tx,
        name,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        if name == "__contains__":
            return self.dv_dict.call_method(tx, name, args, kwargs)
        if name in cmp_name_to_op_mapping:
            other = args[0]
            if not isinstance(other, (SetVariable, DictKeysVariable)):
                return ConstantVariable.create(NotImplemented)
            # Normalize the other side to a real set. Some set-like trackers
            # expose ``set_items`` as another container (DictKeySetVariable
            # returns its backing dict). A set compared with a dict is
            # always unequal, and ordering operators raise TypeError.
            return ConstantVariable.create(
                cmp_name_to_op_mapping[name](self.set_items, set(other.set_items))
            )
        return super().call_method(tx, name, args, kwargs)
class DictValuesVariable(DictViewVariable):
    """Tracks a ``dict.values()`` view.

    Values views are iterable but not set-like, so (unlike keys views) no
    comparison support is provided.
    """

    kv = "values"

    @property
    def view_items_vt(self):
        # The dict's values are used as-is, with no unwrapping (unlike keys).
        return [*self.view_items]

    def python_type(self):
        return dict_values
class DictItemsVariable(DictViewVariable):
    """Tracks a ``dict.items()`` view; unpacks to ``(key, value)`` tuple trackers."""

    kv = "items"

    @property
    def view_items_vt(self):
        pairs = []
        for key, value in self.view_items:
            # Keys are stored wrapped; ``.vt`` yields the underlying tracker.
            pairs.append(variables.TupleVariable([key.vt, value]))
        return pairs

    def python_type(self):
        return dict_items
|